mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
🔴 VLM: compile compatibility (#35724)
* llavas * add mroe models * fix `compile_forward` test for all models * fix copies * make style * also doesn't support cache class * fix some tests * not copied from * ci green? * fix tests * fix copies * fix tests * check with `numel` and remove `item` * fix copies * fix copies * Update src/transformers/models/cohere2/modeling_cohere2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * opt remove cross attn * gemma2 * fixup * fixup * fix newly added test * maybe fixed? * green please? --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
b45cf0e90a
commit
0c78ef6cd3
@ -2016,6 +2016,9 @@ class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
|
||||
class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
config_class = Blip2Config
|
||||
main_input_name = "pixel_values"
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
|
@ -1284,13 +1284,13 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
|
||||
if pixel_values is not None:
|
||||
image_tokens = self.get_image_tokens(pixel_values)
|
||||
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
|
||||
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
|
||||
if n_image_tokens_in_text != n_image_features:
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
|
||||
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
|
||||
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
|
||||
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
|
||||
|
||||
|
@ -25,7 +25,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
@ -701,7 +701,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if isinstance(past_key_values, HybridCache):
|
||||
if isinstance(past_key_values, (HybridCache, StaticCache)):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
||||
|
@ -25,7 +25,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
@ -713,7 +713,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if isinstance(past_key_values, HybridCache):
|
||||
if isinstance(past_key_values, (HybridCache, StaticCache)):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
||||
|
@ -20,7 +20,7 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, HybridCache
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
@ -550,7 +550,7 @@ class Gemma2Model(GemmaModel):
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if isinstance(past_key_values, HybridCache):
|
||||
if isinstance(past_key_values, (HybridCache, StaticCache)):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
||||
|
@ -132,8 +132,6 @@ class GotOcr2Config(PretrainedConfig):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
image_token_index (`int`, *optional*, defaults to 151859):
|
||||
The image token index to encode the image prompt.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
@ -161,13 +159,11 @@ class GotOcr2Config(PretrainedConfig):
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
image_token_index=151859,
|
||||
image_seq_length=576,
|
||||
pad_token_id=-1,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.image_seq_length = image_seq_length
|
||||
self.pad_token_id = pad_token_id
|
||||
|
@ -594,6 +594,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of GotOcr2 isn't meant for training from scratch - only
|
||||
@ -748,89 +750,6 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
|
||||
image_outputs = self.vision_tower(pixel_values).last_hidden_state
|
||||
return self.multi_modal_projector(image_outputs)
|
||||
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == self.config.image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
||||
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
||||
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
if labels is not None:
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||
image_to_overwrite = torch.full(
|
||||
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
if left_padding:
|
||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||
else:
|
||||
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
|
||||
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
|
||||
image_to_overwrite &= padding_mask
|
||||
|
||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||
raise ValueError(
|
||||
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
|
||||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
|
||||
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
|
@ -170,8 +170,6 @@ class GotOcr2Config(PretrainedConfig):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
image_token_index (`int`, *optional*, defaults to 151859):
|
||||
The image token index to encode the image prompt.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
@ -199,13 +197,11 @@ class GotOcr2Config(PretrainedConfig):
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
image_token_index=151859,
|
||||
image_seq_length=576,
|
||||
pad_token_id=-1,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.image_seq_length = image_seq_length
|
||||
self.pad_token_id = pad_token_id
|
||||
|
@ -51,7 +51,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = False # TODO (fix me): compilation fails due to a stide error?
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -129,8 +129,8 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
query = torch.cat((query, query_pass), dim=-1)
|
||||
key = torch.cat((key, key_pass), dim=-1)
|
||||
query = torch.cat((query, query_pass), dim=-1).contiguous()
|
||||
key = torch.cat((key, key_pass), dim=-1).contiguous()
|
||||
|
||||
# Cache QKV values
|
||||
if layer_past is not None:
|
||||
|
@ -1108,6 +1108,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
@ -1116,13 +1117,8 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
@ -1143,7 +1139,6 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
@ -1154,25 +1149,17 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
@ -1182,6 +1169,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
@ -1290,6 +1290,9 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
|
||||
class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
|
||||
config_class = InstructBlipConfig
|
||||
main_input_name = "pixel_values"
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
|
||||
|
||||
def __init__(self, config: InstructBlipConfig):
|
||||
super().__init__(config)
|
||||
|
@ -1284,6 +1284,9 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
|
||||
class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
|
||||
config_class = InstructBlipVideoConfig
|
||||
main_input_name = "pixel_values"
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
|
||||
|
||||
def __init__(self, config: InstructBlipVideoConfig):
|
||||
super().__init__(config)
|
||||
|
@ -37,8 +37,6 @@ class LlavaConfig(PretrainedConfig):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
image_token_index (`int`, *optional*, defaults to 32000):
|
||||
The image token index to encode the image prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
@ -83,7 +81,6 @@ class LlavaConfig(PretrainedConfig):
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
image_token_index=32000,
|
||||
projector_hidden_act="gelu",
|
||||
vision_feature_select_strategy="default",
|
||||
@ -92,7 +89,6 @@ class LlavaConfig(PretrainedConfig):
|
||||
multimodal_projector_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.image_seq_length = image_seq_length
|
||||
|
@ -28,6 +28,7 @@ from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -136,6 +137,8 @@ class LlavaPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Llava isn't meant for training from scratch - only
|
||||
@ -321,89 +324,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == self.config.image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
||||
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
||||
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
if labels is not None:
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||
image_to_overwrite = torch.full(
|
||||
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
if left_padding:
|
||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||
else:
|
||||
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
|
||||
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
|
||||
image_to_overwrite &= padding_mask
|
||||
|
||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||
raise ValueError(
|
||||
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
|
||||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
@ -499,14 +419,14 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
if n_image_tokens != n_image_features:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
|
@ -36,8 +36,6 @@ class LlavaNextConfig(PretrainedConfig):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
image_token_index (`int`, *optional*, defaults to 32000):
|
||||
The image token index to encode the image prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
@ -88,7 +86,6 @@ class LlavaNextConfig(PretrainedConfig):
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
image_token_index=32000,
|
||||
projector_hidden_act="gelu",
|
||||
vision_feature_select_strategy="default",
|
||||
@ -99,7 +96,6 @@ class LlavaNextConfig(PretrainedConfig):
|
||||
multimodal_projector_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.image_seq_length = image_seq_length
|
||||
|
@ -31,6 +31,7 @@ from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -245,6 +246,8 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of LlavaNext isn't meant for training from scratch - only
|
||||
@ -405,245 +408,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
image_features,
|
||||
feature_lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids=None,
|
||||
labels=None,
|
||||
image_token_index=None,
|
||||
ignore_index=-100,
|
||||
):
|
||||
"""
|
||||
Merge input_ids with with image features into final embeddings
|
||||
|
||||
Args:
|
||||
image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
|
||||
All vision vectors of all images in the batch
|
||||
feature_lens (`torch.LongTensor` of shape `(num_images)`):
|
||||
The length of visual embeddings of each image as stacked in `image_features`
|
||||
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
|
||||
Token embeddings before merging with visual embeddings
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Input_ids of tokens, possibly filled with image token
|
||||
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
|
||||
:abels need to be recalculated to support training (if provided)
|
||||
image_token_index (`int`, *optional*)
|
||||
Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
|
||||
ignore_index (`int`, *optional*)
|
||||
Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
|
||||
Returns:
|
||||
final_embedding, final_attention_mask, position_ids, final_labels
|
||||
|
||||
Explanation:
|
||||
each image has variable length embeddings, with length specified by feature_lens
|
||||
image_features is concatenation of all visual embed vectors
|
||||
task: fill each <image> with the correct number of visual embeddings
|
||||
Example:
|
||||
X (5 patches), Y (3 patches), Z (8)
|
||||
X, Y are in the same sequence (in-context learning)
|
||||
if right padding
|
||||
input_ids: [
|
||||
a b c d e f X g h i j k Y l m
|
||||
o p q r Z s t u v _ _ _ _ _ _
|
||||
]
|
||||
input_ids should be: [
|
||||
a b c d e f X X X X X g h i j k Y Y Y l m
|
||||
o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
|
||||
]
|
||||
labels should be: [
|
||||
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
|
||||
o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
|
||||
]
|
||||
elif left padding
|
||||
input_ids: [
|
||||
a b c d e f X g h i j k Y l m
|
||||
_ _ _ _ _ _ o p q r Z s t u v
|
||||
]
|
||||
input_ids should be: [
|
||||
a b c d e f X X X X X g h i j k Y Y Y l m
|
||||
_ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
|
||||
]
|
||||
labels should be: [
|
||||
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
|
||||
_ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
|
||||
]
|
||||
Edge cases:
|
||||
* If tokens are same but image token sizes are different, then cannot infer left or right padding
|
||||
```python
|
||||
cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
|
||||
prompts = [
|
||||
"[INST] <image>\nWhat is shown in this image? [/INST]",
|
||||
"[INST] <image>\nWhat is shown in this image? [/INST]",
|
||||
]
|
||||
inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
|
||||
chart_img has 2634 tokens, while cat_img has 2340 tokens
|
||||
```
|
||||
|
||||
input_ids: [
|
||||
a b c d X g h
|
||||
i j Y k l m n
|
||||
]
|
||||
where X is 3 tokens while Y is 5, this mean after merge
|
||||
if left-padding (batched generation)
|
||||
input_ids should be: [
|
||||
_ _ a b c d X X X g h
|
||||
i j Y Y Y Y Y k l m n
|
||||
]
|
||||
elif (right padding) (training)
|
||||
input_ids should be: [
|
||||
a b c d X X X g h _ _
|
||||
i j Y Y Y Y Y k l m n
|
||||
]
|
||||
"""
|
||||
image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
|
||||
ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
|
||||
|
||||
if self.training and self.padding_side == "left":
|
||||
logger.warning_once(
|
||||
"Padding side is set to 'left' but the model is in training mode. For training "
|
||||
"it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. "
|
||||
"If that's intended, ignore this warning"
|
||||
)
|
||||
if not self.training and self.padding_side == "right":
|
||||
logger.warning_once(
|
||||
"Padding side is set to 'right' but the model is in inference mode. For correct "
|
||||
"generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. "
|
||||
"If that's intended, ignore this warning"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
# ! in llava 1.6, number of patches is variable
|
||||
num_images = feature_lens.size(0)
|
||||
num_image_features, embed_dim = image_features.shape
|
||||
if feature_lens.sum() != num_image_features:
|
||||
raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
|
||||
batch_size = input_ids.shape[0]
|
||||
_left_padding = torch.any(attention_mask[:, 0] == 0)
|
||||
_right_padding = torch.any(attention_mask[:, -1] == 0)
|
||||
|
||||
left_padding = self.padding_side == "left"
|
||||
if batch_size > 1:
|
||||
if _left_padding and _right_padding:
|
||||
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
|
||||
elif _right_padding and left_padding:
|
||||
left_padding = False
|
||||
elif _left_padding and not left_padding:
|
||||
left_padding = True
|
||||
|
||||
# Whether to turn off right padding
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == image_token_index
|
||||
# special_image_token_mask: [bsz, seqlen]
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# num_special_image_tokens: [bsz]
|
||||
# Reserve for padding of num_images
|
||||
total_num_special_image_tokens = torch.sum(special_image_token_mask)
|
||||
if total_num_special_image_tokens != num_images:
|
||||
raise ValueError(
|
||||
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
|
||||
)
|
||||
# Compute the maximum embed dimension
|
||||
# max_image_feature_lens is max_feature_lens per batch
|
||||
feature_lens = feature_lens.to(input_ids.device)
|
||||
feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
|
||||
feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device)
|
||||
embed_sequence_lengths = (
|
||||
(attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
|
||||
)
|
||||
max_embed_dim = embed_sequence_lengths.max()
|
||||
|
||||
batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
# ! instead of special_image_token_mask * (num_image_patches - 1)
|
||||
# special_image_token_mask * (num_feature_len - 1)
|
||||
special_image_token_mask = special_image_token_mask.long()
|
||||
special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
|
||||
new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
|
||||
if left_padding:
|
||||
# shift right token positions so that they are ending at the same number
|
||||
# the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
|
||||
new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
|
||||
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_input_ids = torch.full(
|
||||
(batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
input_ids = input_ids.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
|
||||
final_labels = None
|
||||
if labels is not None:
|
||||
labels = labels.to(target_device)
|
||||
final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||
with torch.no_grad():
|
||||
image_to_overwrite = torch.full(
|
||||
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
|
||||
embed_indices = embed_indices.expand(batch_size, max_embed_dim)
|
||||
embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
|
||||
|
||||
if left_padding:
|
||||
# exclude padding on the left
|
||||
max_embed_dim = max_embed_dim.to(target_device)
|
||||
val = (max_embed_dim - embed_indices) <= embed_seq_lens
|
||||
else:
|
||||
# exclude padding on the right
|
||||
val = embed_indices < embed_seq_lens
|
||||
image_to_overwrite &= val
|
||||
|
||||
if image_to_overwrite.sum() != num_image_features:
|
||||
raise ValueError(
|
||||
f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
|
||||
f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. "
|
||||
f"This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
|
||||
|
||||
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
|
||||
"""
|
||||
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
||||
@ -875,14 +639,14 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
|
||||
n_image_features = image_features.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
|
@ -38,8 +38,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
image_token_index (`int`, *optional*, defaults to 32001):
|
||||
The image token index to encode the image prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
@ -96,7 +94,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
image_token_index=32001,
|
||||
projector_hidden_act="gelu",
|
||||
multimodal_projector_bias=True,
|
||||
@ -116,7 +113,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
self.spatial_pool_stride = spatial_pool_stride
|
||||
self.image_seq_length = image_seq_length
|
||||
self.video_seq_length = video_seq_length
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.multimodal_projector_bias = multimodal_projector_bias
|
||||
|
@ -32,7 +32,13 @@ from ...generation import GenerationMixin
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_llava_next_video import LlavaNextVideoConfig
|
||||
@ -153,6 +159,8 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of LlavaNextVideo isn't meant for training from scratch - only
|
||||
@ -440,245 +448,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
image_features,
|
||||
feature_lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids=None,
|
||||
labels=None,
|
||||
image_token_index=None,
|
||||
ignore_index=-100,
|
||||
):
|
||||
"""
|
||||
Merge input_ids with with image features into final embeddings
|
||||
|
||||
Args:
|
||||
image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
|
||||
All vision vectors of all images in the batch
|
||||
feature_lens (`torch.LongTensor` of shape `(num_images)`):
|
||||
The length of visual embeddings of each image as stacked in `image_features`
|
||||
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
|
||||
Token embeddings before merging with visual embeddings
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Input_ids of tokens, possibly filled with image token
|
||||
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
|
||||
:abels need to be recalculated to support training (if provided)
|
||||
image_token_index (`int`, *optional*)
|
||||
Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
|
||||
ignore_index (`int`, *optional*)
|
||||
Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
|
||||
Returns:
|
||||
final_embedding, final_attention_mask, position_ids, final_labels
|
||||
|
||||
Explanation:
|
||||
each image has variable length embeddings, with length specified by feature_lens
|
||||
image_features is concatenation of all visual embed vectors
|
||||
task: fill each <image> with the correct number of visual embeddings
|
||||
Example:
|
||||
X (5 patches), Y (3 patches), Z (8)
|
||||
X, Y are in the same sequence (in-context learning)
|
||||
if right padding
|
||||
input_ids: [
|
||||
a b c d e f X g h i j k Y l m
|
||||
o p q r Z s t u v _ _ _ _ _ _
|
||||
]
|
||||
input_ids should be: [
|
||||
a b c d e f X X X X X g h i j k Y Y Y l m
|
||||
o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
|
||||
]
|
||||
labels should be: [
|
||||
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
|
||||
o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
|
||||
]
|
||||
elif left padding
|
||||
input_ids: [
|
||||
a b c d e f X g h i j k Y l m
|
||||
_ _ _ _ _ _ o p q r Z s t u v
|
||||
]
|
||||
input_ids should be: [
|
||||
a b c d e f X X X X X g h i j k Y Y Y l m
|
||||
_ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
|
||||
]
|
||||
labels should be: [
|
||||
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
|
||||
_ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
|
||||
]
|
||||
Edge cases:
|
||||
* If tokens are same but image token sizes are different, then cannot infer left or right padding
|
||||
```python
|
||||
cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||
chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
|
||||
prompts = [
|
||||
"[INST] <image>\nWhat is shown in this image? [/INST]",
|
||||
"[INST] <image>\nWhat is shown in this image? [/INST]",
|
||||
]
|
||||
inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
|
||||
chart_img has 2634 tokens, while cat_img has 2340 tokens
|
||||
```
|
||||
|
||||
input_ids: [
|
||||
a b c d X g h
|
||||
i j Y k l m n
|
||||
]
|
||||
where X is 3 tokens while Y is 5, this mean after merge
|
||||
if left-padding (batched generation)
|
||||
input_ids should be: [
|
||||
_ _ a b c d X X X g h
|
||||
i j Y Y Y Y Y k l m n
|
||||
]
|
||||
elif (right padding) (training)
|
||||
input_ids should be: [
|
||||
a b c d X X X g h _ _
|
||||
i j Y Y Y Y Y k l m n
|
||||
]
|
||||
"""
|
||||
image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
|
||||
ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
|
||||
|
||||
if self.training and self.padding_side == "left":
|
||||
logger.warning_once(
|
||||
"Padding side is set to 'left' but the model is in training mode. For training "
|
||||
"it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. "
|
||||
"If that's intended, ignore this warning"
|
||||
)
|
||||
if not self.training and self.padding_side == "right":
|
||||
logger.warning_once(
|
||||
"Padding side is set to 'right' but the model is in inference mode. For correct "
|
||||
"generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. "
|
||||
"If that's intended, ignore this warning"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
# ! in llava 1.6, number of patches is variable
|
||||
num_images = feature_lens.size(0)
|
||||
num_image_features, embed_dim = image_features.shape
|
||||
if feature_lens.sum() != num_image_features:
|
||||
raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
|
||||
batch_size = input_ids.shape[0]
|
||||
_left_padding = torch.any(attention_mask[:, 0] == 0)
|
||||
_right_padding = torch.any(attention_mask[:, -1] == 0)
|
||||
|
||||
left_padding = self.padding_side == "left"
|
||||
if batch_size > 1:
|
||||
if _left_padding and _right_padding:
|
||||
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
|
||||
elif _right_padding and left_padding:
|
||||
left_padding = False
|
||||
elif _left_padding and not left_padding:
|
||||
left_padding = True
|
||||
|
||||
# Whether to turn off right padding
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == image_token_index
|
||||
# special_image_token_mask: [bsz, seqlen]
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# num_special_image_tokens: [bsz]
|
||||
# Reserve for padding of num_images
|
||||
total_num_special_image_tokens = torch.sum(special_image_token_mask)
|
||||
if total_num_special_image_tokens != num_images:
|
||||
raise ValueError(
|
||||
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
|
||||
)
|
||||
# Compute the maximum embed dimension
|
||||
# max_image_feature_lens is max_feature_lens per batch
|
||||
feature_lens = feature_lens.to(input_ids.device)
|
||||
feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
|
||||
feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device)
|
||||
embed_sequence_lengths = (
|
||||
(attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
|
||||
)
|
||||
max_embed_dim = embed_sequence_lengths.max()
|
||||
|
||||
batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
# ! instead of special_image_token_mask * (num_image_patches - 1)
|
||||
# special_image_token_mask * (num_feature_len - 1)
|
||||
special_image_token_mask = special_image_token_mask.long()
|
||||
special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
|
||||
new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
|
||||
if left_padding:
|
||||
# shift right token positions so that they are ending at the same number
|
||||
# the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
|
||||
new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
|
||||
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_input_ids = torch.full(
|
||||
(batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
input_ids = input_ids.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
|
||||
final_labels = None
|
||||
if labels is not None:
|
||||
labels = labels.to(target_device)
|
||||
final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||
with torch.no_grad():
|
||||
image_to_overwrite = torch.full(
|
||||
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
|
||||
embed_indices = embed_indices.expand(batch_size, max_embed_dim)
|
||||
embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
|
||||
|
||||
if left_padding:
|
||||
# exclude padding on the left
|
||||
max_embed_dim = max_embed_dim.to(target_device)
|
||||
val = (max_embed_dim - embed_indices) <= embed_seq_lens
|
||||
else:
|
||||
# exclude padding on the right
|
||||
val = embed_indices < embed_seq_lens
|
||||
image_to_overwrite &= val
|
||||
|
||||
if image_to_overwrite.sum() != num_image_features:
|
||||
raise ValueError(
|
||||
f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
|
||||
f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. "
|
||||
f"This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
|
||||
|
||||
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
|
||||
"""
|
||||
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
||||
@ -948,14 +717,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
|
||||
n_image_features = image_features.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
@ -970,14 +739,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
|
||||
|
||||
|
@ -30,6 +30,7 @@ from transformers.models.llava_next.modeling_llava_next import (
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import (
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
@ -52,8 +53,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
The config object or dictionary of the vision backbone.
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
image_token_index (`int`, *optional*, defaults to 32001):
|
||||
The image token index to encode the image prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
@ -110,7 +109,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
image_token_index=32001,
|
||||
projector_hidden_act="gelu",
|
||||
multimodal_projector_bias=True,
|
||||
@ -130,7 +128,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
self.spatial_pool_stride = spatial_pool_stride
|
||||
self.image_seq_length = image_seq_length
|
||||
self.video_seq_length = video_seq_length
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.multimodal_projector_bias = multimodal_projector_bias
|
||||
@ -479,14 +476,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
|
||||
n_image_features = image_features.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
@ -501,14 +498,14 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
|
||||
|
||||
|
@ -30,6 +30,7 @@ from ...modeling_outputs import ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
@ -250,7 +251,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_sdpa = True
|
||||
|
||||
@ -712,19 +713,15 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
image_newline=self.image_newline,
|
||||
vision_aspect_ratio=vision_aspect_ratio,
|
||||
)
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
|
||||
n_image_features = image_features.shape[0]
|
||||
|
||||
if n_image_tokens != n_image_features:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
@ -741,18 +738,14 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
video_features = torch.cat((video_features, image_newline), dim=1)
|
||||
video_features = video_features.flatten(0, 1)
|
||||
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
special_video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
|
||||
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
special_video_mask = (
|
||||
(input_ids == self.config.video_token_index)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
|
||||
|
||||
|
@ -22,10 +22,10 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
AttentionMaskConverter,
|
||||
)
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
@ -98,6 +98,7 @@ class OPTAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
layer_idx: int = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@ -106,6 +107,13 @@ class OPTAttention(nn.Module):
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.dropout = config.attention_dropout
|
||||
self.enable_bias = config.enable_bias
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.is_causal = True
|
||||
@ -122,9 +130,6 @@ class OPTAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -134,52 +139,33 @@ class OPTAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
# isn't needed in normal attention, but needed in flash attention so to keep the signature same
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(3, 2))
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
|
||||
if attn_weights.dtype == torch.float16:
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
|
||||
else:
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
@ -187,39 +173,19 @@ class OPTAttention(nn.Module):
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}"
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to be reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_probs, value_states)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned aross GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
return attn_output, attn_probs, past_key_value
|
||||
|
||||
|
||||
class OptFlashAttention2(OPTAttention):
|
||||
@ -245,33 +211,33 @@ class OptFlashAttention2(OPTAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, _, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
bsz, query_length, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
# get key, value proj
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
|
||||
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
query_length = query_states.shape[1]
|
||||
tgt_len = key_states.shape[-2]
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
|
||||
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
attn_dropout = self.dropout if self.training else 0.0
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in float16 just to be sure everything works as expected.
|
||||
@ -331,6 +297,7 @@ class OPTSdpaAttention(OPTAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
logger.warning_once(
|
||||
@ -344,24 +311,24 @@ class OPTSdpaAttention(OPTAttention):
|
||||
layer_head_mask=layer_head_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
) # TODO after merge add position_ids=position_ids
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
query_states = self._shape(query_states, -1, bsz)
|
||||
query_states = self.q_proj(hidden_states)
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
# shape now is (bsz, num_heads, seq_len, head_dim), all are continuous
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
@ -378,10 +345,6 @@ class OPTSdpaAttention(OPTAttention):
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
# this model uses the scaling factor in the query projection for some reason, but not in Q@K^T
|
||||
# so we need to scale to remove scaling in SDPA to have similar results with eager.
|
||||
# Maybe needs a change in the model to remove scaling in query projection
|
||||
scale=1.0,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
@ -399,11 +362,11 @@ OPT_ATTENTION_CLASSES = {
|
||||
|
||||
|
||||
class OPTDecoderLayer(nn.Module):
|
||||
def __init__(self, config: OPTConfig):
|
||||
def __init__(self, config: OPTConfig, layer_idx: int = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config)
|
||||
self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
self.dropout = config.dropout
|
||||
@ -425,6 +388,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -440,6 +404,8 @@ class OPTDecoderLayer(nn.Module):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence..
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
@ -456,6 +422,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -524,6 +491,9 @@ class OPTPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["OPTDecoderLayer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
@ -601,6 +571,10 @@ OPT_INPUTS_DOCSTRING = r"""
|
||||
config.n_positions - 1]`. for padding use -1.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@ -643,9 +617,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
else:
|
||||
self.final_layer_norm = None
|
||||
|
||||
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||
self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@ -657,48 +629,130 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
input_shape: Tuple[int, int],
|
||||
past_key_values_length: int,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Updates the causal mask for the decoder.
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to plcae the 4D attention mask on.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
batch_size, seq_length = input_shape
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
attention_mask = (
|
||||
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
if attention_mask is None
|
||||
else attention_mask
|
||||
)
|
||||
|
||||
return causal_attention_mask, attention_mask
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
|
||||
elif attention_mask.shape[1] != mask_seq_length:
|
||||
raise ValueError(
|
||||
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
|
||||
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
|
||||
)
|
||||
if self._use_sdpa and not output_attentions and head_mask is None:
|
||||
causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_attention_mask, attention_mask
|
||||
return causal_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -712,6 +766,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -764,6 +819,10 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
config.n_positions - 1]`. for padding use -1.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -773,51 +832,65 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
if past_key_values is None:
|
||||
logger.warning_once(
|
||||
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
|
||||
"You should pass an instance of `DynamicCache` instead, e.g. "
|
||||
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
|
||||
)
|
||||
|
||||
causal_attention_mask, attention_mask = self._update_causal_mask(
|
||||
inputs_embeds, input_shape, past_key_values_length, attention_mask, head_mask, output_attentions
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
seq_length = past_seen_tokens + inputs_embeds.shape[1]
|
||||
attention_mask = torch.ones(inputs_embeds.shape[0], seq_length, device=inputs_embeds.device)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# embed positions
|
||||
|
||||
# embed positions
|
||||
if position_ids is None:
|
||||
# position_ids = cache_position.unsqueeze(0)
|
||||
position_ids = torch.cumsum(attention_mask, dim=1)
|
||||
position_ids = (position_ids * attention_mask - 1).long()
|
||||
# cut positions if `past_key_values_length` is > 0
|
||||
position_ids = position_ids[:, past_key_values_length:]
|
||||
# cut positions if `past_seen_tokens` is > 0
|
||||
position_ids = position_ids[:, past_seen_tokens:]
|
||||
|
||||
pos_embeds = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
|
||||
pos_embeds = self.embed_positions(attention_mask, past_seen_tokens, position_ids=position_ids)
|
||||
|
||||
if self.project_in is not None:
|
||||
inputs_embeds = self.project_in(inputs_embeds)
|
||||
|
||||
hidden_states = inputs_embeds + pos_embeds.to(inputs_embeds.device)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
|
||||
@ -838,34 +911,34 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
if dropout_probability < self.layerdrop:
|
||||
continue
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_attention_mask,
|
||||
causal_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
None,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
position_ids,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_attention_mask,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
past_key_value=past_key_value,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
@ -881,6 +954,9 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = next_cache.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
@ -930,6 +1006,7 @@ class OPTModel(OPTPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -950,6 +1027,7 @@ class OPTModel(OPTPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
@ -1008,6 +1086,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
@ -1069,6 +1148,10 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
|
||||
config.n_positions - 1]`. for padding use -1.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
|
||||
Returns:
|
||||
|
||||
@ -1107,6 +1190,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
|
@ -29,6 +29,7 @@ from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -508,7 +509,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
|
@ -38,8 +38,6 @@ class VideoLlavaConfig(PretrainedConfig):
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*):
|
||||
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
|
||||
Defaults to `LlamaConfig` if not indicated.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
image_token_index (`int`, *optional*, defaults to 32000):
|
||||
The image token index to encode the image prompt.
|
||||
video_token_index (`int`, *optional*, defaults to 32001):
|
||||
@ -88,7 +86,6 @@ class VideoLlavaConfig(PretrainedConfig):
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
image_token_index=32000,
|
||||
video_token_index=32001,
|
||||
projector_hidden_act="gelu",
|
||||
@ -99,7 +96,6 @@ class VideoLlavaConfig(PretrainedConfig):
|
||||
multimodal_projector_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.video_token_index = video_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
|
@ -28,6 +28,7 @@ from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -137,6 +138,8 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
@ -276,92 +279,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def _merge_input_ids_with_visual_features(
|
||||
self, visual_features, inputs_embeds, input_ids, attention_mask, labels, num_frames=1
|
||||
):
|
||||
num_images, num_image_patches, embed_dim = visual_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index
|
||||
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == special_vision_token
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_seq_len = (num_special_image_tokens.max() * (num_image_patches * num_frames - 1)) + sequence_length
|
||||
batch_indices, non_image_indices = torch.where(input_ids != special_vision_token)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = (
|
||||
torch.cumsum((special_image_token_mask * (num_image_patches * num_frames - 1) + 1), dim=-1) - 1
|
||||
)
|
||||
nb_image_pad = max_seq_len - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
# expand input ids so that the second "merge" with videos does not fail
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, max_seq_len, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_seq_len, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_input_ids = torch.full(
|
||||
(batch_size, max_seq_len), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_seq_len), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
else:
|
||||
final_labels = None
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
||||
image_to_overwrite = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=inputs_embeds.device)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
if left_padding:
|
||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||
else:
|
||||
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
|
||||
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
|
||||
image_to_overwrite &= padding_mask
|
||||
|
||||
if image_to_overwrite.sum() != visual_features.shape[:-1].numel():
|
||||
visual_type = "videos" if num_frames == 8 else "images"
|
||||
num_images //= num_frames
|
||||
raise ValueError(
|
||||
f"The input provided to the model are wrong. The number of {visual_type} tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of {visual_type} given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
|
||||
final_embedding[image_to_overwrite] = visual_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values_images: torch.FloatTensor,
|
||||
@ -579,14 +496,14 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
if n_image_tokens != n_image_features:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
@ -595,14 +512,14 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
|
||||
pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
|
||||
)
|
||||
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
|
||||
n_video_features = video_features.shape[0] * video_features.shape[1]
|
||||
if n_video_tokens != n_video_features:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum()
|
||||
n_video_features = video_features.shape[0] * video_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
|
||||
|
||||
|
@ -37,8 +37,6 @@ class VipLlavaConfig(PretrainedConfig):
|
||||
Custom vision config or dict
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*):
|
||||
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
image_token_index (`int`, *optional*, defaults to 32000):
|
||||
The image token index to encode the image prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
@ -78,7 +76,6 @@ class VipLlavaConfig(PretrainedConfig):
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
image_token_index=32000,
|
||||
projector_hidden_act="gelu",
|
||||
projector_layernorm_eps=1e-5,
|
||||
@ -86,7 +83,6 @@ class VipLlavaConfig(PretrainedConfig):
|
||||
image_seq_length=576,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.projector_layernorm_eps = projector_layernorm_eps
|
||||
|
@ -28,6 +28,7 @@ from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -137,6 +138,8 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of VipLlava isn't meant for training from scratch - only
|
||||
@ -297,89 +300,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
|
||||
image_features = self.multi_modal_projector(image_features)
|
||||
return image_features
|
||||
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||
# 1. Create a mask to know where special image tokens are
|
||||
special_image_token_mask = input_ids == self.config.image_token_index
|
||||
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
||||
# Compute the maximum embed dimension
|
||||
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
||||
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
||||
|
||||
# 2. Compute the positions where text should be written
|
||||
# Calculate new positions for text tokens in merged image-text sequence.
|
||||
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
||||
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
||||
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
||||
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
||||
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
||||
if left_padding:
|
||||
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
||||
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
||||
|
||||
# 3. Create the full embedding, already padded to the maximum position
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
)
|
||||
final_attention_mask = torch.zeros(
|
||||
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
||||
)
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
||||
# set the corresponding tensors into their correct target device.
|
||||
target_device = inputs_embeds.device
|
||||
batch_indices, non_image_indices, text_to_overwrite = (
|
||||
batch_indices.to(target_device),
|
||||
non_image_indices.to(target_device),
|
||||
text_to_overwrite.to(target_device),
|
||||
)
|
||||
attention_mask = attention_mask.to(target_device)
|
||||
|
||||
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
||||
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
||||
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
||||
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
||||
if labels is not None:
|
||||
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
||||
|
||||
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
|
||||
image_to_overwrite = torch.full(
|
||||
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
image_to_overwrite[batch_indices, text_to_overwrite] = False
|
||||
if left_padding:
|
||||
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
||||
else:
|
||||
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
|
||||
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
|
||||
image_to_overwrite &= padding_mask
|
||||
|
||||
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
||||
raise ValueError(
|
||||
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
||||
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
||||
)
|
||||
|
||||
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
||||
final_attention_mask |= image_to_overwrite
|
||||
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
||||
|
||||
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
||||
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
|
||||
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
||||
|
||||
final_embedding[batch_indices, indices_to_mask] = 0
|
||||
|
||||
if labels is None:
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, final_attention_mask, final_labels, position_ids
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
@ -469,14 +389,14 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
|
||||
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
|
||||
)
|
||||
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
if n_image_tokens != n_image_features:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
|
@ -1783,12 +1783,12 @@ class GenerationTesterMixin:
|
||||
model.config.use_cache = True
|
||||
model.config.is_decoder = True
|
||||
batch_size = input_ids.shape[0]
|
||||
max_length = 30
|
||||
max_new_tokens = 10
|
||||
|
||||
# here we force to not stop at eos and go until max-length
|
||||
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
|
||||
generation_kwargs = {
|
||||
"max_length": max_length,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"cache_implementation": "static",
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
}
|
||||
@ -1811,10 +1811,11 @@ class GenerationTesterMixin:
|
||||
|
||||
# we should get `max_length - 1` in shape, not `max_length - embeds_length`.
|
||||
# -1 because the last generated token isn't yet in the cache.
|
||||
cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim)
|
||||
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
|
||||
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
|
||||
max_length = max_new_tokens + inputs_embeds.shape[1] - 1
|
||||
cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
|
||||
self.assertIsInstance(outputs.past_key_values, StaticCache)
|
||||
self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
|
||||
self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
@ -2022,7 +2023,7 @@ class GenerationTesterMixin:
|
||||
|
||||
config.is_decoder = True
|
||||
batch_size = main_input.shape[0]
|
||||
seq_length = main_input.shape[-1]
|
||||
seq_length = self.model_tester.seq_length
|
||||
max_new_tokens = 20
|
||||
|
||||
for dtype in (torch.float32, torch.float16):
|
||||
@ -2134,7 +2135,15 @@ class GenerationTesterMixin:
|
||||
# compilation-specific setup
|
||||
torch.compiler.reset() # prevent cached compilation from being used in the test
|
||||
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
|
||||
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
|
||||
|
||||
# BLIP is the only exception with custom generate which call `self.lm.generate()`
|
||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
||||
# compatible with multimodality
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
model.language_model.generation_config.compile_config._compile_all_devices = True
|
||||
else:
|
||||
# force compilation (e.g. fast CI, CPU
|
||||
model.generation_config.compile_config._compile_all_devices = True
|
||||
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
@ -2175,7 +2184,14 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertFalse(isinstance(decoder_cache, DynamicCache))
|
||||
self.assertTrue(decoder_cache.is_compileable)
|
||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
||||
|
||||
# BLIP is the only exception with custom generate which call `self.lm.generate()`
|
||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
||||
# compatible with multimodality
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
self.assertTrue(hasattr(model.language_model, "_compiled_call"))
|
||||
else:
|
||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
||||
|
||||
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
||||
@ -2198,9 +2214,19 @@ class GenerationTesterMixin:
|
||||
# compilation-specific setup
|
||||
torch.compiler.reset() # prevent cached compilation from being used in the test
|
||||
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
|
||||
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
|
||||
if not has_defined_cache_implementation:
|
||||
model.generation_config.cache_implementation = "static"
|
||||
|
||||
# BLIP is the only exception with custom generate which call `self.lm.generate()`
|
||||
# We should avoid such calls in all subsequent multimodal models and try to make `generate()`
|
||||
# compatible with multimodality
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
model.language_model.generation_config.compile_config._compile_all_devices = True
|
||||
if not has_defined_cache_implementation:
|
||||
model.language_model.generation_config.cache_implementation = "static"
|
||||
else:
|
||||
# force compilation (e.g. fast CI, CPU)
|
||||
model.generation_config.compile_config._compile_all_devices = True
|
||||
if not has_defined_cache_implementation:
|
||||
model.generation_config.cache_implementation = "static"
|
||||
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
|
||||
output_generate = model.generate(
|
||||
@ -2218,8 +2244,10 @@ class GenerationTesterMixin:
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
# Sanity check: compilation has happened
|
||||
self.assertTrue(hasattr(model, "_compiled_call"))
|
||||
if "blip" in model.__class__.__name__.lower():
|
||||
self.assertTrue(hasattr(model.language_model, "_compiled_call"))
|
||||
else:
|
||||
self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
|
@ -286,10 +286,18 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
|
||||
def test_generate_from_inputs_embeds_1_beam_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Unsupported")
|
||||
@unittest.skip(reason="Dynamic control flow due to MoE")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Dynamic control flow due to MoE")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Dynamic control flow due to MoE")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
@ -816,6 +816,10 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
pass
|
||||
|
||||
@unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
|
||||
# this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py
|
||||
class Blip2TextModelTester:
|
||||
|
@ -386,10 +386,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Doesn't work, tensors are not almost same") # TODO raushan fixme
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("VQ-VAE module doesn't initialize weights properly")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
@ -256,12 +256,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="GotOcr2 needs a dynamic control flow to pass pixel values to the forward function only in the first generation step"
|
||||
)
|
||||
def test_generate_compile_1_end_to_end(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
@ -838,6 +838,14 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="We only test the model that takes in multiple images")
|
||||
def test_model(self):
|
||||
pass
|
||||
|
@ -530,6 +530,12 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -546,6 +546,12 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -316,14 +316,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
@ -365,22 +365,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="CPU offload is not yet supported")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
@ -391,6 +375,10 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("LLaVA Next has dynamic control flow in unpadding")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
@ -382,26 +382,6 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="CPU offload is not yet supported")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)"
|
||||
)
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)"
|
||||
)
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
@ -412,6 +392,10 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("LLaVA Next Video has dynamic control flow in unpadding")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
@ -346,6 +346,10 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("LLaVA OneVision has dynamic control flow in unpadding")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
@ -540,7 +540,6 @@ class MT5ModelTester:
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"use_cache": False,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
@ -81,7 +81,7 @@ class OPTModelTester:
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=20,
|
||||
max_position_embeddings=50,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
@ -89,7 +89,6 @@ class OPTModelTester:
|
||||
num_labels=3,
|
||||
word_embed_proj_dim=16,
|
||||
type_sequence_label_size=2,
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@ -113,7 +112,6 @@ class OPTModelTester:
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.word_embed_proj_dim = word_embed_proj_dim
|
||||
self.is_encoder_decoder = False
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
||||
@ -143,7 +141,6 @@ class OPTModelTester:
|
||||
embed_dim=self.embed_dim,
|
||||
is_encoder_decoder=False,
|
||||
word_embed_proj_dim=self.word_embed_proj_dim,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
|
||||
def get_pipeline_config(self):
|
||||
|
@ -545,7 +545,6 @@ class T5ModelTester:
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"use_cache": False,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
@ -226,14 +226,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
@ -306,14 +306,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because it is not yet supported in LLava")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
@ -4324,10 +4324,6 @@ class ModelTesterMixin:
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]:
|
||||
self.skipTest(
|
||||
reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input"
|
||||
)
|
||||
if config.model_type in ["paligemma"]:
|
||||
self.skipTest(
|
||||
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
|
||||
@ -4778,6 +4774,9 @@ class ModelTesterMixin:
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
set_model_for_less_flaky_test(model)
|
||||
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
continue # this model doesn't accept position ids as input
|
||||
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
|
Loading…
Reference in New Issue
Block a user