🔴 VLM: compile compatibility (#35724)

* llavas

* add mroe models

* fix `compile_forward` test for all models

* fix copies

* make style

* also doesn't support cache class

* fix some tests

* not copied from

* ci green?

* fix tests

* fix copies

* fix tests

* check with `numel` and remove `item`

* fix copies

* fix copies

* Update src/transformers/models/cohere2/modeling_cohere2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* opt remove cross attn

* gemma2

* fixup

* fixup

* fix newly added test

* maybe fixed?

* green please?

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay 2025-02-14 15:23:49 +01:00 committed by GitHub
parent b45cf0e90a
commit 0c78ef6cd3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 464 additions and 1215 deletions

View File

@ -2016,6 +2016,9 @@ class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
config_class = Blip2Config
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
def __init__(self, config: Blip2Config):
super().__init__(config)

View File

@ -1284,13 +1284,13 @@ class ChameleonModel(ChameleonPreTrainedModel):
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
if n_image_tokens_in_text != n_image_features:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)

View File

@ -25,7 +25,7 @@ import torch
import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@ -701,7 +701,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

View File

@ -25,7 +25,7 @@ import torch
import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
@ -713,7 +713,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

View File

@ -20,7 +20,7 @@ import torch.nn as nn
import torch.utils.checkpoint
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
@ -550,7 +550,7 @@ class Gemma2Model(GemmaModel):
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

View File

@ -132,8 +132,6 @@ class GotOcr2Config(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 151859):
The image token index to encode the image prompt.
image_seq_length (`int`, *optional*, defaults to 576):
@ -161,13 +159,11 @@ class GotOcr2Config(PretrainedConfig):
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=151859,
image_seq_length=576,
pad_token_id=-1,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.image_seq_length = image_seq_length
self.pad_token_id = pad_token_id

View File

@ -594,6 +594,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of GotOcr2 isn't meant for training from scratch - only
@ -748,89 +750,6 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
if left_padding:
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
else:
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
image_to_overwrite &= padding_mask
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None:
final_labels = None
return final_embedding, final_attention_mask, final_labels, position_ids
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(

View File

@ -170,8 +170,6 @@ class GotOcr2Config(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 151859):
The image token index to encode the image prompt.
image_seq_length (`int`, *optional*, defaults to 576):
@ -199,13 +197,11 @@ class GotOcr2Config(PretrainedConfig):
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=151859,
image_seq_length=576,
pad_token_id=-1,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.image_seq_length = image_seq_length
self.pad_token_id = pad_token_id

View File

@ -51,7 +51,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = False # TODO (fix me): compilation fails due to a stide error?
_supports_static_cache = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -129,8 +129,8 @@ class GPTNeoXJapaneseAttention(nn.Module):
cos, sin = position_embeddings
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)
query = torch.cat((query, query_pass), dim=-1).contiguous()
key = torch.cat((key, key_pass), dim=-1).contiguous()
# Cache QKV values
if layer_past is not None:

View File

@ -1108,6 +1108,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
router_logits=all_router_logits,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@ -1116,13 +1117,8 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
@ -1143,7 +1139,6 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@ -1154,25 +1149,17 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
@ -1182,6 +1169,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask

View File

@ -1290,6 +1290,9 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
config_class = InstructBlipConfig
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
def __init__(self, config: InstructBlipConfig):
super().__init__(config)

View File

@ -1284,6 +1284,9 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
config_class = InstructBlipVideoConfig
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
def __init__(self, config: InstructBlipVideoConfig):
super().__init__(config)

View File

@ -37,8 +37,6 @@ class LlavaConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@ -83,7 +81,6 @@ class LlavaConfig(PretrainedConfig):
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
@ -92,7 +89,6 @@ class LlavaConfig(PretrainedConfig):
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length

View File

@ -28,6 +28,7 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -136,6 +137,8 @@ class LlavaPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of Llava isn't meant for training from scratch - only
@ -321,89 +324,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
if left_padding:
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
else:
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
image_to_overwrite &= padding_mask
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None:
final_labels = None
return final_embedding, final_attention_mask, final_labels, position_ids
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@ -499,14 +419,14 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
image_sizes=image_sizes,
)
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

View File

@ -36,8 +36,6 @@ class LlavaNextConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@ -88,7 +86,6 @@ class LlavaNextConfig(PretrainedConfig):
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
@ -99,7 +96,6 @@ class LlavaNextConfig(PretrainedConfig):
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length

View File

@ -31,6 +31,7 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -245,6 +246,8 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of LlavaNext isn't meant for training from scratch - only
@ -405,245 +408,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
def get_decoder(self):
return self.language_model.get_decoder()
def _merge_input_ids_with_image_features(
self,
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids=None,
labels=None,
image_token_index=None,
ignore_index=-100,
):
"""
Merge input_ids with with image features into final embeddings
Args:
image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
All vision vectors of all images in the batch
feature_lens (`torch.LongTensor` of shape `(num_images)`):
The length of visual embeddings of each image as stacked in `image_features`
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
Token embeddings before merging with visual embeddings
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Input_ids of tokens, possibly filled with image token
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
:abels need to be recalculated to support training (if provided)
image_token_index (`int`, *optional*)
Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
ignore_index (`int`, *optional*)
Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
Returns:
final_embedding, final_attention_mask, position_ids, final_labels
Explanation:
each image has variable length embeddings, with length specified by feature_lens
image_features is concatenation of all visual embed vectors
task: fill each <image> with the correct number of visual embeddings
Example:
X (5 patches), Y (3 patches), Z (8)
X, Y are in the same sequence (in-context learning)
if right padding
input_ids: [
a b c d e f X g h i j k Y l m
o p q r Z s t u v _ _ _ _ _ _
]
input_ids should be: [
a b c d e f X X X X X g h i j k Y Y Y l m
o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
]
labels should be: [
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
]
elif left padding
input_ids: [
a b c d e f X g h i j k Y l m
_ _ _ _ _ _ o p q r Z s t u v
]
input_ids should be: [
a b c d e f X X X X X g h i j k Y Y Y l m
_ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
]
labels should be: [
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
_ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
]
Edge cases:
* If tokens are same but image token sizes are different, then cannot infer left or right padding
```python
cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
prompts = [
"[INST] <image>\nWhat is shown in this image? [/INST]",
"[INST] <image>\nWhat is shown in this image? [/INST]",
]
inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
chart_img has 2634 tokens, while cat_img has 2340 tokens
```
input_ids: [
a b c d X g h
i j Y k l m n
]
where X is 3 tokens while Y is 5, this mean after merge
if left-padding (batched generation)
input_ids should be: [
_ _ a b c d X X X g h
i j Y Y Y Y Y k l m n
]
elif (right padding) (training)
input_ids should be: [
a b c d X X X g h _ _
i j Y Y Y Y Y k l m n
]
"""
image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
if self.training and self.padding_side == "left":
logger.warning_once(
"Padding side is set to 'left' but the model is in training mode. For training "
"it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. "
"If that's intended, ignore this warning"
)
if not self.training and self.padding_side == "right":
logger.warning_once(
"Padding side is set to 'right' but the model is in inference mode. For correct "
"generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. "
"If that's intended, ignore this warning"
)
with torch.no_grad():
# ! in llava 1.6, number of patches is variable
num_images = feature_lens.size(0)
num_image_features, embed_dim = image_features.shape
if feature_lens.sum() != num_image_features:
raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
batch_size = input_ids.shape[0]
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)
left_padding = self.padding_side == "left"
if batch_size > 1:
if _left_padding and _right_padding:
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
elif _right_padding and left_padding:
left_padding = False
elif _left_padding and not left_padding:
left_padding = True
# Whether to turn off right padding
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == image_token_index
# special_image_token_mask: [bsz, seqlen]
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# num_special_image_tokens: [bsz]
# Reserve for padding of num_images
total_num_special_image_tokens = torch.sum(special_image_token_mask)
if total_num_special_image_tokens != num_images:
raise ValueError(
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
)
# Compute the maximum embed dimension
# max_image_feature_lens is max_feature_lens per batch
feature_lens = feature_lens.to(input_ids.device)
feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device)
embed_sequence_lengths = (
(attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
)
max_embed_dim = embed_sequence_lengths.max()
batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
# ! instead of special_image_token_mask * (num_image_patches - 1)
# special_image_token_mask * (num_feature_len - 1)
special_image_token_mask = special_image_token_mask.long()
special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
if left_padding:
# shift right token positions so that they are ending at the same number
# the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
final_input_ids = torch.full(
(batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
input_ids = input_ids.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
final_labels = None
if labels is not None:
labels = labels.to(target_device)
final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
with torch.no_grad():
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
embed_indices = embed_indices.expand(batch_size, max_embed_dim)
embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
if left_padding:
# exclude padding on the left
max_embed_dim = max_embed_dim.to(target_device)
val = (max_embed_dim - embed_indices) <= embed_seq_lens
else:
# exclude padding on the right
val = embed_indices < embed_seq_lens
image_to_overwrite &= val
if image_to_overwrite.sum() != num_image_features:
raise ValueError(
f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. "
f"This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@ -875,14 +639,14 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
image_newline=self.image_newline,
)
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

View File

@ -38,8 +38,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32001):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@ -96,7 +94,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32001,
projector_hidden_act="gelu",
multimodal_projector_bias=True,
@ -116,7 +113,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
self.spatial_pool_stride = spatial_pool_stride
self.image_seq_length = image_seq_length
self.video_seq_length = video_seq_length
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.multimodal_projector_bias = multimodal_projector_bias

View File

@ -32,7 +32,13 @@ from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_llava_next_video import LlavaNextVideoConfig
@ -153,6 +159,8 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of LlavaNextVideo isn't meant for training from scratch - only
@ -440,245 +448,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
def get_decoder(self):
return self.language_model.get_decoder()
def _merge_input_ids_with_image_features(
self,
image_features,
feature_lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids=None,
labels=None,
image_token_index=None,
ignore_index=-100,
):
"""
Merge input_ids with with image features into final embeddings
Args:
image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
All vision vectors of all images in the batch
feature_lens (`torch.LongTensor` of shape `(num_images)`):
The length of visual embeddings of each image as stacked in `image_features`
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
Token embeddings before merging with visual embeddings
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Input_ids of tokens, possibly filled with image token
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
:abels need to be recalculated to support training (if provided)
image_token_index (`int`, *optional*)
Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
ignore_index (`int`, *optional*)
Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
Returns:
final_embedding, final_attention_mask, position_ids, final_labels
Explanation:
each image has variable length embeddings, with length specified by feature_lens
image_features is concatenation of all visual embed vectors
task: fill each <image> with the correct number of visual embeddings
Example:
X (5 patches), Y (3 patches), Z (8)
X, Y are in the same sequence (in-context learning)
if right padding
input_ids: [
a b c d e f X g h i j k Y l m
o p q r Z s t u v _ _ _ _ _ _
]
input_ids should be: [
a b c d e f X X X X X g h i j k Y Y Y l m
o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
]
labels should be: [
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
]
elif left padding
input_ids: [
a b c d e f X g h i j k Y l m
_ _ _ _ _ _ o p q r Z s t u v
]
input_ids should be: [
a b c d e f X X X X X g h i j k Y Y Y l m
_ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
]
labels should be: [
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
_ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
]
Edge cases:
* If tokens are same but image token sizes are different, then cannot infer left or right padding
```python
cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
prompts = [
"[INST] <image>\nWhat is shown in this image? [/INST]",
"[INST] <image>\nWhat is shown in this image? [/INST]",
]
inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
chart_img has 2634 tokens, while cat_img has 2340 tokens
```
input_ids: [
a b c d X g h
i j Y k l m n
]
where X is 3 tokens while Y is 5, this mean after merge
if left-padding (batched generation)
input_ids should be: [
_ _ a b c d X X X g h
i j Y Y Y Y Y k l m n
]
elif (right padding) (training)
input_ids should be: [
a b c d X X X g h _ _
i j Y Y Y Y Y k l m n
]
"""
image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
if self.training and self.padding_side == "left":
logger.warning_once(
"Padding side is set to 'left' but the model is in training mode. For training "
"it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. "
"If that's intended, ignore this warning"
)
if not self.training and self.padding_side == "right":
logger.warning_once(
"Padding side is set to 'right' but the model is in inference mode. For correct "
"generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. "
"If that's intended, ignore this warning"
)
with torch.no_grad():
# ! in llava 1.6, number of patches is variable
num_images = feature_lens.size(0)
num_image_features, embed_dim = image_features.shape
if feature_lens.sum() != num_image_features:
raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
batch_size = input_ids.shape[0]
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)
left_padding = self.padding_side == "left"
if batch_size > 1:
if _left_padding and _right_padding:
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
elif _right_padding and left_padding:
left_padding = False
elif _left_padding and not left_padding:
left_padding = True
# Whether to turn off right padding
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == image_token_index
# special_image_token_mask: [bsz, seqlen]
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# num_special_image_tokens: [bsz]
# Reserve for padding of num_images
total_num_special_image_tokens = torch.sum(special_image_token_mask)
if total_num_special_image_tokens != num_images:
raise ValueError(
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
)
# Compute the maximum embed dimension
# max_image_feature_lens is max_feature_lens per batch
feature_lens = feature_lens.to(input_ids.device)
feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device)
embed_sequence_lengths = (
(attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
)
max_embed_dim = embed_sequence_lengths.max()
batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
# ! instead of special_image_token_mask * (num_image_patches - 1)
# special_image_token_mask * (num_feature_len - 1)
special_image_token_mask = special_image_token_mask.long()
special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
if left_padding:
# shift right token positions so that they are ending at the same number
# the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
final_input_ids = torch.full(
(batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
input_ids = input_ids.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
final_labels = None
if labels is not None:
labels = labels.to(target_device)
final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
with torch.no_grad():
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
embed_indices = embed_indices.expand(batch_size, max_embed_dim)
embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
if left_padding:
# exclude padding on the left
max_embed_dim = max_embed_dim.to(target_device)
val = (max_embed_dim - embed_indices) <= embed_seq_lens
else:
# exclude padding on the right
val = embed_indices < embed_seq_lens
image_to_overwrite &= val
if image_to_overwrite.sum() != num_image_features:
raise ValueError(
f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. "
f"This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@ -948,14 +717,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
image_newline=self.image_newline,
)
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@ -970,14 +739,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
if n_video_tokens != n_video_features:
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

View File

@ -30,6 +30,7 @@ from transformers.models.llava_next.modeling_llava_next import (
from ...configuration_utils import PretrainedConfig
from ...utils import (
is_torchdynamo_compiling,
logging,
)
from ..auto import CONFIG_MAPPING, AutoConfig
@ -52,8 +53,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32001):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@ -110,7 +109,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32001,
projector_hidden_act="gelu",
multimodal_projector_bias=True,
@ -130,7 +128,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
self.spatial_pool_stride = spatial_pool_stride
self.image_seq_length = image_seq_length
self.video_seq_length = video_seq_length
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.multimodal_projector_bias = multimodal_projector_bias
@ -479,14 +476,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
image_newline=self.image_newline,
)
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@ -501,14 +498,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
if n_video_tokens != n_video_features:
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

View File

@ -30,6 +30,7 @@ from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
is_torchdynamo_compiling,
logging,
)
from ...utils.deprecation import deprecate_kwarg
@ -250,7 +251,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support
_supports_static_cache = True
_supports_quantized_cache = True
_supports_sdpa = True
@ -712,19 +713,15 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
image_newline=self.image_newline,
vision_aspect_ratio=vision_aspect_ratio,
)
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@ -741,18 +738,14 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1)
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
if n_video_tokens != n_video_features:
special_video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_index).sum()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_video_mask = (
(input_ids == self.config.video_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)

View File

@ -22,10 +22,10 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
AttentionMaskConverter,
)
from ...modeling_outputs import (
BaseModelOutputWithPast,
@ -98,6 +98,7 @@ class OPTAttention(nn.Module):
def __init__(
self,
config: OPTConfig,
layer_idx: int = None,
**kwargs,
):
super().__init__()
@ -106,6 +107,13 @@ class OPTAttention(nn.Module):
self.num_heads = config.num_attention_heads
self.dropout = config.attention_dropout
self.enable_bias = config.enable_bias
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.head_dim = self.embed_dim // self.num_heads
self.is_causal = True
@ -122,9 +130,6 @@ class OPTAttention(nn.Module):
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
@ -134,52 +139,33 @@ class OPTAttention(nn.Module):
output_attentions: bool = False,
# isn't needed in normal attention, but needed in flash attention so to keep the signature same
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
# save all key/value_states to cache to be re-used for fast auto-regressive generation
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
attn_weights = torch.matmul(query_states, key_states.transpose(3, 2))
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if attn_weights.dtype == torch.float16:
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
@ -187,39 +173,19 @@ class OPTAttention(nn.Module):
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_probs, value_states)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.transpose(1, 2).contiguous()
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
return attn_output, attn_probs, past_key_value
class OptFlashAttention2(OPTAttention):
@ -245,33 +211,33 @@ class OptFlashAttention2(OPTAttention):
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
position_ids: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, _, _ = hidden_states.size()
# get query proj
bsz, query_length, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
# get key, value proj
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)
query_length = query_states.shape[1]
tgt_len = key_states.shape[-2]
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
# save all key/value_states to cache to be re-used for fast auto-regressive generation
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
attn_dropout = self.dropout if self.training else 0.0
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
@ -331,6 +297,7 @@ class OPTSdpaAttention(OPTAttention):
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
position_ids: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions or layer_head_mask is not None:
logger.warning_once(
@ -344,24 +311,24 @@ class OPTSdpaAttention(OPTAttention):
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
) # TODO after merge add position_ids=position_ids
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states) * self.scaling
query_states = self._shape(query_states, -1, bsz)
query_states = self.q_proj(hidden_states)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
# get key, value proj
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states)
# shape now is (bsz, num_heads, seq_len, head_dim), all are continuous
# save all key/value_states to cache to be re-used for fast auto-regressive generation
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
causal_mask = attention_mask
if attention_mask is not None:
@ -378,10 +345,6 @@ class OPTSdpaAttention(OPTAttention):
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
# this model uses the scaling factor in the query projection for some reason, but not in Q@K^T
# so we need to scale to remove scaling in SDPA to have similar results with eager.
# Maybe needs a change in the model to remove scaling in query projection
scale=1.0,
)
attn_output = attn_output.transpose(1, 2).contiguous()
@ -399,11 +362,11 @@ OPT_ATTENTION_CLASSES = {
class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
def __init__(self, config: OPTConfig, layer_idx: int = None):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config)
self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
@ -425,6 +388,7 @@ class OPTDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -440,6 +404,8 @@ class OPTDecoderLayer(nn.Module):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence..
"""
residual = hidden_states
@ -456,6 +422,7 @@ class OPTDecoderLayer(nn.Module):
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@ -524,6 +491,9 @@ class OPTPreTrainedModel(PreTrainedModel):
_no_split_modules = ["OPTDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.init_std
@ -601,6 +571,10 @@ OPT_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`. for padding use -1.
[What are position IDs?](../glossary#position-ids)
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@ -643,9 +617,7 @@ class OPTDecoder(OPTPreTrainedModel):
else:
self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@ -657,48 +629,130 @@ class OPTDecoder(OPTPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
inputs_embeds: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Updates the causal mask for the decoder.
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
batch_size, seq_length = input_shape
mask_seq_length = past_key_values_length + seq_length
if self._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if attention_mask is None
else attention_mask
)
return causal_attention_mask, attention_mask
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
if self._use_sdpa and not output_attentions and head_mask is None:
causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_attention_mask, attention_mask
return causal_mask
def forward(
self,
@ -712,6 +766,7 @@ class OPTDecoder(OPTPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
@ -764,6 +819,10 @@ class OPTDecoder(OPTPreTrainedModel):
config.n_positions - 1]`. for padding use -1.
[What are position IDs?](../glossary#position-ids)
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -773,51 +832,65 @@ class OPTDecoder(OPTPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if input_ids is not None:
input_ids = input_ids.view(-1, input_ids.shape[-1])
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if past_key_values is None:
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
"You should pass an instance of `DynamicCache` instead, e.g. "
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
)
causal_attention_mask, attention_mask = self._update_causal_mask(
inputs_embeds, input_shape, past_key_values_length, attention_mask, head_mask, output_attentions
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if attention_mask is None:
seq_length = past_seen_tokens + inputs_embeds.shape[1]
attention_mask = torch.ones(inputs_embeds.shape[0], seq_length, device=inputs_embeds.device)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions
# embed positions
if position_ids is None:
# position_ids = cache_position.unsqueeze(0)
position_ids = torch.cumsum(attention_mask, dim=1)
position_ids = (position_ids * attention_mask - 1).long()
# cut positions if `past_key_values_length` is > 0
position_ids = position_ids[:, past_key_values_length:]
# cut positions if `past_seen_tokens` is > 0
position_ids = position_ids[:, past_seen_tokens:]
pos_embeds = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
pos_embeds = self.embed_positions(attention_mask, past_seen_tokens, position_ids=position_ids)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds.to(inputs_embeds.device)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
@ -838,34 +911,34 @@ class OPTDecoder(OPTPreTrainedModel):
if dropout_probability < self.layerdrop:
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_attention_mask,
causal_mask,
head_mask[idx] if head_mask is not None else None,
None,
output_attentions,
use_cache,
position_ids,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
attention_mask=causal_mask,
position_ids=position_ids,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
@ -881,6 +954,9 @@ class OPTDecoder(OPTPreTrainedModel):
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
@ -930,6 +1006,7 @@ class OPTModel(OPTPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -950,6 +1027,7 @@ class OPTModel(OPTPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
if not return_dict:
@ -1008,6 +1086,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
@ -1069,6 +1148,10 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
config.n_positions - 1]`. for padding use -1.
[What are position IDs?](../glossary#position-ids)
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
Returns:
@ -1107,6 +1190,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
logits = self.lm_head(outputs[0]).contiguous()

View File

@ -29,6 +29,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -508,7 +509,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if inputs_embeds[special_image_mask].numel() != image_features.numel():
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "

View File

@ -38,8 +38,6 @@ class VideoLlavaConfig(PretrainedConfig):
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
Defaults to `LlamaConfig` if not indicated.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
video_token_index (`int`, *optional*, defaults to 32001):
@ -88,7 +86,6 @@ class VideoLlavaConfig(PretrainedConfig):
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
video_token_index=32001,
projector_hidden_act="gelu",
@ -99,7 +96,6 @@ class VideoLlavaConfig(PretrainedConfig):
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.video_token_index = video_token_index
self.projector_hidden_act = projector_hidden_act

View File

@ -28,6 +28,7 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -137,6 +138,8 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = (
@ -276,92 +279,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
def get_decoder(self):
return self.language_model.get_decoder()
def _merge_input_ids_with_visual_features(
self, visual_features, inputs_embeds, input_ids, attention_mask, labels, num_frames=1
):
num_images, num_image_patches, embed_dim = visual_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == special_vision_token
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_seq_len = (num_special_image_tokens.max() * (num_image_patches * num_frames - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != special_vision_token)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = (
torch.cumsum((special_image_token_mask * (num_image_patches * num_frames - 1) + 1), dim=-1) - 1
)
nb_image_pad = max_seq_len - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
# expand input ids so that the second "merge" with videos does not fail
final_embedding = torch.zeros(
batch_size, max_seq_len, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_seq_len, dtype=attention_mask.dtype, device=inputs_embeds.device
)
final_input_ids = torch.full(
(batch_size, max_seq_len), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
if labels is not None:
final_labels = torch.full(
(batch_size, max_seq_len), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
else:
final_labels = None
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
image_to_overwrite = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=inputs_embeds.device)
image_to_overwrite[batch_indices, text_to_overwrite] = False
if left_padding:
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
else:
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
image_to_overwrite &= padding_mask
if image_to_overwrite.sum() != visual_features.shape[:-1].numel():
visual_type = "videos" if num_frames == 8 else "images"
num_images //= num_frames
raise ValueError(
f"The input provided to the model are wrong. The number of {visual_type} tokens is {torch.sum(special_image_token_mask)} while"
f" the number of {visual_type} given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = visual_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids
def get_image_features(
self,
pixel_values_images: torch.FloatTensor,
@ -579,14 +496,14 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@ -595,14 +512,14 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
)
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0] * video_features.shape[1]
if n_video_tokens != n_video_features:
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_index).sum()
n_video_features = video_features.shape[0] * video_features.shape[1]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

View File

@ -37,8 +37,6 @@ class VipLlavaConfig(PretrainedConfig):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@ -78,7 +76,6 @@ class VipLlavaConfig(PretrainedConfig):
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
projector_layernorm_eps=1e-5,
@ -86,7 +83,6 @@ class VipLlavaConfig(PretrainedConfig):
image_seq_length=576,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.projector_layernorm_eps = projector_layernorm_eps

View File

@ -28,6 +28,7 @@ from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -137,6 +138,8 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of VipLlava isn't meant for training from scratch - only
@ -297,89 +300,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
image_features = self.multi_modal_projector(image_features)
return image_features
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
if left_padding:
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
else:
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
image_to_overwrite &= padding_mask
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None:
final_labels = None
return final_embedding, final_attention_mask, final_labels, position_ids
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@ -469,14 +389,14 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
)
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

View File

@ -1783,12 +1783,12 @@ class GenerationTesterMixin:
model.config.use_cache = True
model.config.is_decoder = True
batch_size = input_ids.shape[0]
max_length = 30
max_new_tokens = 10
# here we force to not stop at eos and go until max-length
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
generation_kwargs = {
"max_length": max_length,
"max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}
@ -1811,10 +1811,11 @@ class GenerationTesterMixin:
# we should get `max_length - 1` in shape, not `max_length - embeds_length`.
# -1 because the last generated token isn't yet in the cache.
cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim)
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
max_length = max_new_tokens + inputs_embeds.shape[1] - 1
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
self.assertIsInstance(outputs.past_key_values, StaticCache)
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
@ -2022,7 +2023,7 @@ class GenerationTesterMixin:
config.is_decoder = True
batch_size = main_input.shape[0]
seq_length = main_input.shape[-1]
seq_length = self.model_tester.seq_length
max_new_tokens = 20
for dtype in (torch.float32, torch.float16):
@ -2134,7 +2135,15 @@ class GenerationTesterMixin:
# compilation-specific setup
torch.compiler.reset() # prevent cached compilation from being used in the test
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
# BLIP is the only exception with custom generate which call `self.lm.generate()`
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
# compatible with multimodality
if "blip" in model.__class__.__name__.lower():
model.language_model.generation_config.compile_config._compile_all_devices = True
else:
# force compilation (e.g. fast CI, CPU
model.generation_config.compile_config._compile_all_devices = True
generation_kwargs = {
"do_sample": False,
@ -2175,7 +2184,14 @@ class GenerationTesterMixin:
)
self.assertFalse(isinstance(decoder_cache, DynamicCache))
self.assertTrue(decoder_cache.is_compileable)
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
# BLIP is the only exception with custom generate which call `self.lm.generate()`
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
# compatible with multimodality
if "blip" in model.__class__.__name__.lower():
self.assertTrue(hasattr(model.language_model, "_compiled_call"))
else:
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
self._check_similar_generate_outputs(dynamic_result, compiled_result)
@ -2198,9 +2214,19 @@ class GenerationTesterMixin:
# compilation-specific setup
torch.compiler.reset() # prevent cached compilation from being used in the test
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
if not has_defined_cache_implementation:
model.generation_config.cache_implementation = "static"
# BLIP is the only exception with custom generate which call `self.lm.generate()`
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
# compatible with multimodality
if "blip" in model.__class__.__name__.lower():
model.language_model.generation_config.compile_config._compile_all_devices = True
if not has_defined_cache_implementation:
model.language_model.generation_config.cache_implementation = "static"
else:
# force compilation (e.g. fast CI, CPU)
model.generation_config.compile_config._compile_all_devices = True
if not has_defined_cache_implementation:
model.generation_config.cache_implementation = "static"
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
output_generate = model.generate(
@ -2218,8 +2244,10 @@ class GenerationTesterMixin:
**inputs_dict,
)
# Sanity check: compilation has happened
self.assertTrue(hasattr(model, "_compiled_call"))
if "blip" in model.__class__.__name__.lower():
self.assertTrue(hasattr(model.language_model, "_compiled_call"))
else:
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)

View File

@ -286,10 +286,18 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
def test_generate_from_inputs_embeds_1_beam_search(self):
pass
@unittest.skip(reason="Unsupported")
@unittest.skip(reason="Dynamic control flow due to MoE")
def test_generate_with_static_cache(self):
pass
@unittest.skip(reason="Dynamic control flow due to MoE")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(reason="Dynamic control flow due to MoE")
def test_generate_compile_model_forward(self):
pass
@require_torch
class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@ -816,6 +816,10 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
def test_generate_from_inputs_embeds(self, _, num_beams):
pass
@unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
# this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py
class Blip2TextModelTester:

View File

@ -386,10 +386,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
def test_cpu_offload(self):
pass
@unittest.skip("Doesn't work, tensors are not almost same") # TODO raushan fixme
def test_custom_4d_attention_mask(self):
pass
@unittest.skip("VQ-VAE module doesn't initialize weights properly")
def test_initialization(self):
pass

View File

@ -256,12 +256,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_past_key_values_format(self):
pass
@unittest.skip(
reason="GotOcr2 needs a dynamic control flow to pass pixel values to the forward function only in the first generation step"
)
def test_generate_compile_1_end_to_end(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass

View File

@ -838,6 +838,14 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
def test_custom_4d_attention_mask(self):
pass
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
def test_generate_with_static_cache(self):
pass
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
def test_generate_compile_model_forward(self):
pass
@unittest.skip(reason="We only test the model that takes in multiple images")
def test_model(self):
pass

View File

@ -530,6 +530,12 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
def test_save_load_fast_init_to_base(self):
pass
@unittest.skip(
"InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present"
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -546,6 +546,12 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
def test_save_load_fast_init_to_base(self):
pass
@unittest.skip(
"InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present"
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -316,14 +316,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass

View File

@ -365,22 +365,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="CPU offload is not yet supported")
def test_cpu_offload(self):
pass
@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@ -391,6 +375,10 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("LLaVA Next has dynamic control flow in unpadding")
def test_generate_compile_model_forward(self):
pass
@require_torch
class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@ -382,26 +382,6 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="CPU offload is not yet supported")
def test_cpu_offload(self):
pass
@unittest.skip(
reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)"
)
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip(
reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)"
)
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@ -412,6 +392,10 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("LLaVA Next Video has dynamic control flow in unpadding")
def test_generate_compile_model_forward(self):
pass
@require_torch
class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@ -346,6 +346,10 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("LLaVA OneVision has dynamic control flow in unpadding")
def test_generate_compile_model_forward(self):
pass
@require_torch
class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@ -540,7 +540,6 @@ class MT5ModelTester:
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"use_cache": False,
}
return config, inputs_dict

View File

@ -81,7 +81,7 @@ class OPTModelTester:
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
max_position_embeddings=50,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
@ -89,7 +89,6 @@ class OPTModelTester:
num_labels=3,
word_embed_proj_dim=16,
type_sequence_label_size=2,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -113,7 +112,6 @@ class OPTModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.word_embed_proj_dim = word_embed_proj_dim
self.is_encoder_decoder = False
self.attn_implementation = attn_implementation
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
@ -143,7 +141,6 @@ class OPTModelTester:
embed_dim=self.embed_dim,
is_encoder_decoder=False,
word_embed_proj_dim=self.word_embed_proj_dim,
attn_implementation=self.attn_implementation,
)
def get_pipeline_config(self):

View File

@ -545,7 +545,6 @@ class T5ModelTester:
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"use_cache": False,
}
return config, inputs_dict

View File

@ -226,14 +226,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass

View File

@ -306,14 +306,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Compile not yet supported because it is not yet supported in LLava")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass

View File

@ -4324,10 +4324,6 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]:
self.skipTest(
reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input"
)
if config.model_type in ["paligemma"]:
self.skipTest(
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
@ -4778,6 +4774,9 @@ class ModelTesterMixin:
model = model_class(config).to(device=torch_device, dtype=torch.float32)
set_model_for_less_flaky_test(model)
if "position_ids" not in inspect.signature(model.forward).parameters:
continue # this model doesn't accept position ids as input
(
input_ids,
position_ids,