Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[VLMs] support attention backends (#37576)
* update models
* why rename
* return attn weights when sdpa
* fixes
* fix attn implementation composite
* fix moshi
* add message
* add typings
* use explicitly all flags for each attn type
* fix some tests
* import what is needed
* kosmos on main has new attention already, yay
* new models in main, run fixup
* won't fix kosmos yet
* fix-copies
* clean up after rebasing
* fix tests
* style
* don't cast attns to fp32
* did we update ruff? oke, let's just do what it asks
* fix pixtral after rebase
This commit is contained in:
parent e296c63cd4
commit d23aae2b8c
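
As a rough usage sketch of what this change enables (not part of the diff; the checkpoint name and dtype below are illustrative assumptions), the affected vision-language models can now be loaded with an explicitly chosen attention backend, and composite models propagate that choice to their sub-modules according to the per-class _supports_* flags added in the hunks below:

import torch
from transformers import Blip2ForConditionalGeneration

model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",  # example checkpoint, chosen here only for illustration
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2" / "flex_attention" where supported
)
print(model.config._attn_implementation)  # backend actually selected for this model

The diff follows.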
@@ -833,6 +833,7 @@ class PretrainedConfig(PushToHubMixin):
if "model_type" in value:
# Needs to be set even if it's not in the diff
diff["model_type"] = value["model_type"]

serializable_config_dict[key] = diff
elif (
key not in default_config_dict
@@ -1003,6 +1004,8 @@ class PretrainedConfig(PushToHubMixin):
del d["_commit_hash"]
if "_attn_implementation_internal" in d:
del d["_attn_implementation_internal"]
if "_attn_implementation_autoset" in d:
del d["_attn_implementation_autoset"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in d:
del d["base_model_tp_plan"]

@@ -3430,9 +3430,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]

# Unset attn implementation so it can be set to another one when loading back
model_to_save.config._attn_implementation_autoset = False

# If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:

@@ -669,6 +669,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = False
_supports_sdpa = True
_supports_cache_class = True
_supports_attention_backend = True

def _init_weights(self, module):
std = self.config.initializer_range
@@ -1409,6 +1410,7 @@ class AriaModel(AriaPreTrainedModel):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features

@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
def forward(
self,
@@ -1424,6 +1426,7 @@ class AriaModel(AriaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, AriaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1470,16 +1473,16 @@ class AriaModel(AriaPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)

output = AriaModelOutputWithPast(
return AriaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()

def _create_patch_attention_mask(self, pixel_mask):
if pixel_mask is None:
@@ -1563,7 +1566,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1645,6 +1648,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs[0]
@@ -1655,7 +1659,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)

return AriaCausalLMOutputWithPast(

@@ -32,6 +32,7 @@ from ...image_utils import (
valid_images,
validate_preprocess_arguments,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
@@ -40,6 +41,7 @@ from ...tokenization_utils import (
TextInput,
)
from ...utils import (
LossKwargs,
TensorType,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -1240,6 +1242,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = False
_supports_sdpa = True
_supports_cache_class = True
_supports_attention_backend = True

def _init_weights(self, module):
std = self.config.initializer_range
@@ -1290,6 +1293,9 @@ class AriaTextModel(LlamaModel):
self.post_init()


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
"""
Aria model for causal language modeling tasks.
@@ -1434,6 +1440,7 @@ class AriaModel(LlavaModel):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features

@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
def forward(
self,
@@ -1449,6 +1456,7 @@ class AriaModel(LlavaModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, AriaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1495,16 +1503,16 @@ class AriaModel(LlavaModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)

output = AriaModelOutputWithPast(
return AriaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()


@add_start_docstrings(
@@ -1533,7 +1541,7 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1615,6 +1623,7 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs[0]
@@ -1625,7 +1634,7 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)

return AriaCausalLMOutputWithPast(

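The Aria hunks above, like the later models in this diff, replace the ad-hoc **loss_kwargs with a single typed kwargs bundle that the forward pass threads through to both the attention backend and the loss function. A minimal standalone sketch of that pattern, reusing the names from the diff (the toy forward function is an assumption for illustration, not library code):

from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


def forward(input_ids, **kwargs: Unpack[KwargsForCausalLM]):
    # One **kwargs entry point can now carry both attention-backend arguments
    # (e.g. cu_seq_lens_q for flash attention) and loss arguments (e.g. num_items_in_batch),
    # which the model forwards untouched to the attention functions and to self.loss_function.
    ...
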
@ -27,9 +27,12 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -124,6 +127,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = False
|
||||
_supports_static_cache = False
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
@ -358,6 +362,7 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -375,7 +380,7 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, AyaVisionModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -434,17 +439,19 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = AyaVisionModelOutputWithPast(
|
||||
return AyaVisionModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -512,7 +519,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, AyaVisionCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -589,7 +596,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
image_sizes=image_sizes,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -599,7 +606,9 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return AyaVisionCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -20,6 +20,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.models.llava.modeling_llava import (
|
||||
KwargsForCausalLM,
|
||||
LlavaCausalLMOutputWithPast,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaModel,
|
||||
@ -27,6 +28,7 @@ from transformers.models.llava.modeling_llava import (
|
||||
)
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
logging,
|
||||
@ -148,7 +150,7 @@ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, AyaVisionCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -213,7 +215,7 @@ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
image_sizes=image_sizes,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -25,15 +25,18 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@ -255,6 +258,30 @@ class Blip2VisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> BLIP doesn't cast attn weights to fp32
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Blip2Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@ -270,7 +297,8 @@ class Blip2Attention(nn.Module):
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = nn.Dropout(config.attention_dropout)
|
||||
self.is_causal = False
|
||||
self.attention_dropout = config.attention_dropout
|
||||
|
||||
# small tweak here compared to CLIP, no bias here
|
||||
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
|
||||
@ -296,6 +324,7 @@ class Blip2Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
@ -308,31 +337,32 @@ class Blip2Attention(nn.Module):
|
||||
)
|
||||
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
|
||||
attention_scores = attention_scores * self.scale
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask=None,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
|
||||
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
output = self.projection(context_layer)
|
||||
|
||||
outputs = (output, attention_probs) if output_attentions else (output, None)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.projection(attn_output)
|
||||
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output, None)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -410,6 +440,10 @@ class Blip2PreTrainedModel(PreTrainedModel):
|
||||
config_class = Blip2Config
|
||||
base_model_prefix = "blip"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_attention_backend = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
_no_split_modules = [
|
||||
"Blip2Attention",
|
||||
@ -1332,6 +1366,11 @@ class Blip2TextEmbeddings(nn.Module):
|
||||
BLIP_2_QFORMER_START_DOCSTRING,
|
||||
)
|
||||
class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
_supports_attention_backend = False # adds position on attn weights before last matmul
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_sdpa = False
|
||||
_supports_flex_attn = False
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
@ -1511,6 +1550,9 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
|
||||
@ -1526,10 +1568,10 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
|
||||
self.vision_model = Blip2VisionModel(config.vision_config)
|
||||
self.vision_model = Blip2VisionModel._from_config(config.vision_config)
|
||||
|
||||
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
||||
self.qformer = Blip2QFormerModel(config.qformer_config)
|
||||
self.qformer = Blip2QFormerModel._from_config(config.qformer_config)
|
||||
|
||||
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
|
||||
if config.use_decoder_only_language_model:
|
||||
@ -1580,6 +1622,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
@ -1611,6 +1654,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
@ -1624,6 +1668,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return text_outputs
|
||||
@ -1749,6 +1794,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
Returns:
|
||||
@ -1826,6 +1872,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
**kwargs,
|
||||
)
|
||||
logits = outputs.logits if return_dict else outputs[0]
|
||||
loss = None
|
||||
@ -1851,6 +1898,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True, # toggle for easier access to loss/logits below
|
||||
labels=labels,
|
||||
**kwargs,
|
||||
)
|
||||
loss = outputs.loss
|
||||
logits = outputs.logits
|
||||
@ -1981,10 +2029,10 @@ class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
|
||||
self.vision_model = Blip2VisionModel(config.vision_config)
|
||||
self.vision_model = Blip2VisionModel._from_config(config.vision_config)
|
||||
|
||||
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
||||
self.qformer = Blip2QFormerModel(config.qformer_config)
|
||||
self.qformer = Blip2QFormerModel._from_config(config.qformer_config)
|
||||
|
||||
# vision projection layer
|
||||
self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size)
|
||||
@ -2102,10 +2150,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
|
||||
self.vision_model = Blip2VisionModel(config.vision_config)
|
||||
self.vision_model = Blip2VisionModel._from_config(config.vision_config)
|
||||
|
||||
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
||||
self.qformer = Blip2QFormerModel(config.qformer_config)
|
||||
self.qformer = Blip2QFormerModel._from_config(config.qformer_config)
|
||||
|
||||
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
|
||||
if config.use_decoder_only_language_model:
|
||||
@ -2180,6 +2228,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
Returns:
|
||||
@ -2308,6 +2357,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
logits = outputs.logits if return_dict else outputs[0]
|
||||
loss = None
|
||||
@ -2334,6 +2384,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
return_dict=True, # toggle for easier access to loss/logits below
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
loss = outputs.loss
|
||||
logits = outputs.logits
|
||||
@ -2463,12 +2514,12 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
|
||||
self.vision_model = Blip2VisionModel(config.vision_config)
|
||||
self.vision_model = Blip2VisionModel._from_config(config.vision_config)
|
||||
|
||||
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
||||
|
||||
self.embeddings = Blip2TextEmbeddings(config.qformer_config)
|
||||
self.qformer = Blip2QFormerModel(config.qformer_config)
|
||||
self.qformer = Blip2QFormerModel._from_config(config.qformer_config)
|
||||
|
||||
# vision projection layer
|
||||
self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size)
|
||||
|
@ -14,31 +14,32 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Chameleon model."""
|
||||
|
||||
import math
|
||||
from functools import cached_property
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
is_torch_flex_attn_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
@ -235,6 +236,33 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class ChameleonAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@ -259,6 +287,7 @@ class ChameleonAttention(nn.Module):
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
self.model_parallel_size = config.model_parallel_size
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
@ -338,144 +367,26 @@ class ChameleonAttention(nn.Module):
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
|
||||
# TODO(joao): add me back asap :)
|
||||
class ChameleonFlashAttention2(ChameleonAttention):
|
||||
"""
|
||||
Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
# Ignore copy
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if isinstance(past_key_value, StaticCache):
|
||||
raise ValueError(
|
||||
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
||||
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||
)
|
||||
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
|
||||
query_states = self.q_norm(query_states)
|
||||
|
||||
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
|
||||
key_states = self.k_norm(key_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
|
||||
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (ChameleonRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
@ -487,114 +398,13 @@ class ChameleonFlashAttention2(ChameleonAttention):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class ChameleonSdpaAttention(ChameleonAttention):
|
||||
"""
|
||||
Chameleon attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`ChameleonAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from ChameleonAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"ChameleonModel is using ChameleonSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
|
||||
query_states = self.q_norm(query_states)
|
||||
|
||||
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
|
||||
key_states = self.k_norm(key_states)
|
||||
|
||||
query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None and cache_position is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
CHAMELEON_ATTENTION_CLASSES = {
|
||||
"eager": ChameleonAttention,
|
||||
"flash_attention_2": ChameleonFlashAttention2,
|
||||
"sdpa": ChameleonSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
|
||||
# TODO(joao): add me back asap :)
|
||||
class ChameleonDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ChameleonConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = ChameleonMLP(config)
|
||||
self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -669,7 +479,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = ChameleonMLP(config)
|
||||
self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -1052,6 +862,7 @@ class ChameleonPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_param_buffer_assignment = False
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
@ -1256,6 +1067,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1342,6 +1154,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -1498,6 +1311,9 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Chameleon Model with a head on top used for outputting logits for next token prediction.",
|
||||
CHAMELEON_START_DOCSTRING,
|
||||
@ -1532,6 +1348,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -1548,6 +1365,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1596,6 +1414,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -1607,22 +1426,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -1010,6 +1010,7 @@ class Emu3VQVAE(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
_no_split_modules = [
|
||||
"Emu3VQVAETemporalResnetBlock",
|
||||
"Emu3VQVAEAttentionBlock",
|
||||
@ -1202,6 +1203,7 @@ class Emu3PreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_param_buffer_assignment = False
|
||||
_supports_attention_backend = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
@ -1836,6 +1838,7 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
image = self.vqmodel.decode(image_tokens)
|
||||
return image
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -1851,6 +1854,7 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1884,8 +1888,9 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
@ -1941,6 +1946,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -2007,8 +2013,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -2018,7 +2025,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -25,10 +25,12 @@ import torch.utils.checkpoint
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@ -41,6 +43,7 @@ from ..chameleon.modeling_chameleon import (
|
||||
ChameleonVQVAEEncoderConvDownsample,
|
||||
)
|
||||
from ..llama.modeling_llama import (
|
||||
KwargsForCausalLM,
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
@ -736,6 +739,7 @@ class Emu3VQVAE(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
_no_split_modules = [
|
||||
"Emu3VQVAETemporalResnetBlock",
|
||||
"Emu3VQVAEAttentionBlock",
|
||||
@ -898,6 +902,7 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
|
||||
"Emu3DecoderLayer",
|
||||
]
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.get_text_config().initializer_range
|
||||
@ -1179,6 +1184,7 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
image = self.vqmodel.decode(image_tokens)
|
||||
return image
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -1194,6 +1200,7 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1227,8 +1234,9 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
@ -1284,6 +1292,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1350,8 +1359,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -1361,7 +1371,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -21,10 +21,19 @@ import torch.utils.checkpoint
from torch import nn

from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...models.auto.modeling_auto import AutoModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
from .configuration_fuyu import FuyuConfig

@ -58,6 +67,10 @@ class FuyuPreTrainedModel(PreTrainedModel):
config_class = FuyuConfig
base_model_prefix = "fuyu"
supports_gradient_checkpointing = True
_supports_attention_backend = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_no_split_modules = []
_skip_keys_device_placement = "past_key_values"

@ -142,6 +155,9 @@ FUYU_INPUTS_DOCSTRING = r"""
"""

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

@add_start_docstrings(
"""The Fuyu model which consists of a vision backbone and a language model, without a language modeling head.""",
FUYU_START_DOCSTRING,
@ -224,8 +240,8 @@ class FuyuModel(FuyuPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -323,6 +339,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.model.get_decoder()

@can_return_tuple
@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -392,7 +409,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
return_dict=return_dict,
return_dict=True,
# don't pass kwargs because Persimmon-backbone doesn't accept FA2 kwargs yet, TODO: raushan
)

@ -407,10 +424,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=logits,

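# --- Illustrative sketch (not part of the commit) --------------------------------
# `class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...` merges two typed
# keyword groups so a single `**kwargs: Unpack[KwargsForCausalLM]` annotation can
# carry both attention-backend options and loss options through `forward`. The
# pattern, reduced to a self-contained toy with hypothetical field names:
from typing_extensions import TypedDict, Unpack


class AttnKwargsSketch(TypedDict, total=False):
    sliding_window: int        # placeholder for an attention-backend option


class LossKwargsSketch(TypedDict, total=False):
    num_items_in_batch: int    # placeholder for a loss-normalization option


class CausalLMKwargsSketch(AttnKwargsSketch, LossKwargsSketch): ...


def forward_sketch(**kwargs: Unpack[CausalLMKwargsSketch]) -> None:
    # The same dict is simply passed along to the attention and loss call sites.
    print(sorted(kwargs))


forward_sketch(sliding_window=4096, num_items_in_batch=8)
# ----------------------------------------------------------------------------------
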
@ -30,9 +30,12 @@ import torch.nn.functional as F

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -619,6 +622,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True

def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -745,6 +749,7 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)

@can_return_tuple
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
def forward(
self,
@ -759,6 +764,7 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, GotOcr2ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -800,16 +806,19 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)

output = GotOcr2ModelOutputWithPast(
return GotOcr2ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

@add_start_docstrings(
@ -874,6 +883,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -937,6 +947,8 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)

hidden_states = outputs[0]
@ -946,7 +958,9 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)

return GotOcr2CausalLMOutputWithPast(
loss=loss,

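# --- Illustrative sketch (not part of the commit) --------------------------------
# Several hunks above drop the trailing `return output if return_dict else
# output.to_tuple()` and instead always build the output dataclass, decorating
# `forward` with `@can_return_tuple`. The rough idea behind such a decorator, shown
# as a self-contained toy (names here are hypothetical, not the upstream code):
import functools
from dataclasses import astuple, dataclass
from typing import Optional


@dataclass
class ToyModelOutput:
    last_hidden_state: float
    loss: Optional[float] = None


def toy_can_return_tuple(forward):
    @functools.wraps(forward)
    def wrapper(*args, return_dict: bool = True, **kwargs):
        output = forward(*args, **kwargs)  # the wrapped forward always builds the dataclass
        return output if return_dict else astuple(output)

    return wrapper


@toy_can_return_tuple
def toy_forward(x: float) -> ToyModelOutput:
    return ToyModelOutput(last_hidden_state=2.0 * x)


print(toy_forward(1.0))                     # ToyModelOutput(last_hidden_state=2.0, loss=None)
print(toy_forward(1.0, return_dict=False))  # (2.0, None)
# ----------------------------------------------------------------------------------
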
@ -21,6 +21,7 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from transformers.models.llava.modeling_llava import (
|
||||
KwargsForCausalLM,
|
||||
LlavaCausalLMOutputWithPast,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaModel,
|
||||
@ -30,6 +31,8 @@ from transformers.models.llava.modeling_llava import (
|
||||
from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -393,6 +396,7 @@ class GotOcr2Model(LlavaModel):
|
||||
image_outputs = self.vision_tower(pixel_values).last_hidden_state
|
||||
return self.multi_modal_projector(image_outputs)
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -407,6 +411,7 @@ class GotOcr2Model(LlavaModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, GotOcr2ModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -448,16 +453,16 @@ class GotOcr2Model(LlavaModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = GotOcr2ModelOutputWithPast(
|
||||
return GotOcr2ModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
@ -479,6 +484,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -542,6 +548,8 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -551,7 +559,9 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return GotOcr2CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -20,24 +20,27 @@
"""PyTorch Idefics model."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig, PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PretrainedConfig, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
@ -500,6 +503,30 @@ class IdeficsMLP(nn.Module):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights

# this was adapted from LlamaAttention
class IdeficsAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -515,11 +542,13 @@ class IdeficsAttention(nn.Module):
layer_idx: Optional[int] = None,
):
super().__init__()
self.config = config
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.dropout = dropout
self.is_causal = True
self.scaling = self.head_dim**-0.5

self.layer_idx = layer_idx
if layer_idx is None:
@ -596,6 +625,7 @@ class IdeficsAttention(nn.Module):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# if key_value_states are provided this layer is used as a cross-attention layer
is_cross_attention = self.is_cross_attention or key_value_states is not None
@ -631,47 +661,33 @@ class IdeficsAttention(nn.Module):
query_states = self.q_layer_norm(query_states)
key_states = self.k_layer_norm(key_states)

causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
attention_interface: Callable = eager_attention_forward

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scaling,
**kwargs,
)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)

attn_weights = None
if output_attentions:
logger.warning_once(
"attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
)
attn_weights = None

return attn_output, attn_weights, past_key_value

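# --- Illustrative sketch (not part of the commit) --------------------------------
# The IdeficsAttention hunk above swaps a hard-coded scaled_dot_product_attention
# call for a lookup keyed by `config._attn_implementation`, falling back to the
# eager path when attention weights are requested. The dispatch pattern, reduced to
# a self-contained toy (the registry and names below are hypothetical, not the real
# `ALL_ATTENTION_FUNCTIONS` mapping):
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn.functional as F


def toy_eager_attention(q, k, v, mask: Optional[torch.Tensor], scaling: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    weights = torch.matmul(q, k.transpose(-1, -2)) * scaling
    if mask is not None:
        weights = weights + mask
    weights = F.softmax(weights, dim=-1)
    return torch.matmul(weights, v), weights  # the eager path can expose the weights


def toy_sdpa_attention(q, k, v, mask: Optional[torch.Tensor], scaling: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scaling)
    return out, None  # fused kernels do not return attention weights


TOY_ATTENTION_FUNCTIONS: Dict[str, Callable] = {"eager": toy_eager_attention, "sdpa": toy_sdpa_attention}


def pick_attention(impl: str, output_attentions: bool) -> Callable:
    # Mirrors the diff's logic: anything non-eager is looked up in the registry,
    # except that requesting weights under SDPA falls back to the eager path.
    if impl != "eager" and not (impl == "sdpa" and output_attentions):
        return TOY_ATTENTION_FUNCTIONS[impl]
    return toy_eager_attention
# ----------------------------------------------------------------------------------
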
@ -706,6 +722,7 @@ class IdeficsDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -734,6 +751,7 @@ class IdeficsDecoderLayer(nn.Module):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@ -833,6 +851,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -875,6 +894,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
key_value_states=image_hidden_states,
attention_mask=image_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
# Fill in zeros for cross_attention hidden_states of tokens attending to no images
@ -927,7 +947,9 @@ class IdeficsPreTrainedModel(PreTrainedModel):
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
_supports_sdpa = True
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_static_cache = False  # IDEFICS cannot compile due to dynamic control flow when checking inputs
_supports_attention_backend = True

def _init_weights(self, module):
# important: this ported version of Idefics isn't meant for training from scratch - only
@ -1029,6 +1051,9 @@ LLAMA_INPUTS_DOCSTRING = r"""
"""

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
@ -1112,6 +1137,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value

@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
@ -1130,6 +1156,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
device = input_ids.device if input_ids is not None else inputs_embeds.device

@ -1292,6 +1319,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
past_key_value=None,  # not implemented
**kwargs,
)
hidden_states = outputs[0]

@ -1303,6 +1331,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)

return layer_outputs
@ -1348,6 +1377,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
cross_layer_interval=self.cross_layer_interval,
gated_cross_attn_layers=self.gated_cross_attn_layers,
cache_position=cache_position,
**kwargs,
)

hidden_states = layer_outputs[0]
@ -1368,12 +1398,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size)
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
if v is not None
)

return IdeficsBaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@ -1565,6 +1590,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
):
output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings

@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=IdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1585,6 +1611,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1641,8 +1668,9 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs[0]
@ -1650,24 +1678,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):

loss = None
if labels is not None:
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

return IdeficsCausalLMOutputWithPast(
loss=loss,

@ -20,17 +20,20 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -514,6 +517,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
@ -1089,6 +1093,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
|
||||
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device)
|
||||
return new_inputs_embeds
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(
|
||||
"""
|
||||
Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
|
||||
@ -1117,6 +1122,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, Idefics2BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1226,15 +1232,13 @@ class Idefics2Model(Idefics2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if return_legacy_cache and use_cache:
|
||||
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
|
||||
|
||||
return Idefics2BaseModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
@ -1244,6 +1248,9 @@ class Idefics2Model(Idefics2PreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The Idefics2 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
|
||||
IDEFICS2_START_DOCSTRING,
|
||||
@ -1292,6 +1299,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Idefics2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -1311,6 +1319,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1386,7 +1395,8 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -1396,26 +1406,9 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
|
||||
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
|
||||
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return Idefics2CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -20,17 +20,20 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -532,6 +535,7 @@ class Idefics3PreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
@ -816,6 +820,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
|
||||
return new_inputs_embeds
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(
|
||||
"""
|
||||
Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
|
||||
@ -843,6 +848,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -939,12 +945,10 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
|
||||
|
||||
return Idefics3BaseModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
@ -954,6 +958,9 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
|
||||
IDEFICS3_START_DOCSTRING,
|
||||
@ -1009,6 +1016,7 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(IDEFICS3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Idefics3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -1028,6 +1036,7 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Idefics3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1117,7 +1126,8 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -1127,26 +1137,9 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
|
||||
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
|
||||
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return Idefics3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -16,27 +16,30 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
@ -159,6 +162,30 @@ class InstructBlipVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBLIP doesn't cast attn weights to fp32
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights

# Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip
class InstructBlipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -175,7 +202,8 @@ class InstructBlipAttention(nn.Module):
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = nn.Dropout(config.attention_dropout)
self.is_causal = False
self.attention_dropout = config.attention_dropout

# small tweak here compared to CLIP, no bias here
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
@ -201,6 +229,7 @@ class InstructBlipAttention(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

@ -213,31 +242,32 @@ class InstructBlipAttention(nn.Module):
)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_interface: Callable = eager_attention_forward

attention_scores = attention_scores * self.scale
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=None,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
**kwargs,
)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)

new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
context_layer = context_layer.reshape(new_context_layer_shape)

output = self.projection(context_layer)

outputs = (output, attention_probs) if output_attentions else (output, None)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.projection(attn_output)

outputs = (attn_output, attn_weights) if output_attentions else (attn_output, None)
return outputs

@ -315,6 +345,10 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
|
||||
config_class = InstructBlipConfig
|
||||
base_model_prefix = "blip"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_attention_backend = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
|
||||
@ -1087,6 +1121,11 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
|
||||
instruction as input.
|
||||
"""
|
||||
|
||||
_supports_attention_backend = False # adds position on attn weights before last matmul
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_sdpa = False
|
||||
_supports_flex_attn = False
|
||||
|
||||
def __init__(self, config: InstructBlipQFormerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
@ -1277,6 +1316,9 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
InstructBLIP base Model consisting of language model, qformer and vision encoder.
|
||||
@ -1337,6 +1379,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
|
||||
if hasattr(self.language_model, "_hf_hook"):
|
||||
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -1352,6 +1395,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@ -1404,6 +1448,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
outputs = self.language_model(
|
||||
@ -1415,11 +1460,9 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (vision_outputs, query_outputs, outputs)
|
||||
|
||||
return InstructBlipForConditionalGenerationModelOutput(
|
||||
vision_outputs=vision_outputs,
|
||||
qformer_outputs=query_outputs,
|
||||
@ -1448,10 +1491,10 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
def __init__(self, config: InstructBlipConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.vision_model = InstructBlipVisionModel(config.vision_config)
|
||||
self.vision_model = InstructBlipVisionModel._from_config(config.vision_config)
|
||||
|
||||
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
||||
self.qformer = InstructBlipQFormerModel(config.qformer_config)
|
||||
self.qformer = InstructBlipQFormerModel._from_config(config.qformer_config)
|
||||
|
||||
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
|
||||
|
||||
@ -1516,6 +1559,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
if hasattr(self.language_model, "_hf_hook"):
|
||||
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=InstructBlipForConditionalGenerationModelOutput, config_class=InstructBlipVisionConfig
|
||||
@ -1535,6 +1579,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
@ -1646,21 +1691,15 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
logits = outputs.logits if return_dict else outputs[0]
|
||||
loss = None
|
||||
# we compute the loss here since we need to take into account the sequence length of the query embeds
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
logits = logits[:, -labels.size(1) :, :]
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous().to(logits.device)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(reduction="mean")
|
||||
|
||||
loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
|
||||
else:
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
@ -1672,14 +1711,11 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
loss = outputs.loss if return_dict else outputs[0]
|
||||
logits = outputs.logits if return_dict else outputs[1]
|
||||
|
||||
if not return_dict:
|
||||
output = (logits, vision_outputs, query_outputs, outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return InstructBlipForConditionalGenerationModelOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
|
@ -21,26 +21,29 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
@ -130,6 +133,30 @@ class InstructBlipVideoVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBlipVideo doesn't cast attn weights to fp32
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class InstructBlipVideoAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@ -145,7 +172,8 @@ class InstructBlipVideoAttention(nn.Module):
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = nn.Dropout(config.attention_dropout)
|
||||
self.is_causal = False
|
||||
self.attention_dropout = config.attention_dropout
|
||||
|
||||
# small tweak here compared to CLIP, no bias here
|
||||
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
|
||||
@ -171,6 +199,7 @@ class InstructBlipVideoAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
@ -183,31 +212,32 @@ class InstructBlipVideoAttention(nn.Module):
|
||||
)
|
||||
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
|
||||
attention_scores = attention_scores * self.scale
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask=None,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
|
||||
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
output = self.projection(context_layer)
|
||||
|
||||
outputs = (output, attention_probs) if output_attentions else (output, None)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.projection(attn_output)
|
||||
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output, None)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -852,6 +882,9 @@ class InstructBlipVideoQFormerEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
@ -945,6 +978,10 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
|
||||
config_class = InstructBlipVideoConfig
|
||||
base_model_prefix = "blip"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_attention_backend = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
|
||||
@ -1049,6 +1086,11 @@ class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
|
||||
instruction as input.
|
||||
"""
|
||||
|
||||
_supports_attention_backend = False # adds position on attn weights before last matmul
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_sdpa = False
|
||||
_supports_flex_attn = False
|
||||
|
||||
def __init__(self, config: InstructBlipVideoQFormerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
@ -1332,6 +1374,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
|
||||
if hasattr(self.language_model, "_hf_hook"):
|
||||
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -1347,6 +1390,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@ -1409,6 +1453,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
outputs = self.language_model(
|
||||
@ -1420,11 +1465,9 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (vision_outputs, query_outputs, outputs)
|
||||
|
||||
return InstructBlipVideoForConditionalGenerationModelOutput(
|
||||
vision_outputs=vision_outputs,
|
||||
qformer_outputs=query_outputs,
|
||||
@ -1453,10 +1496,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
def __init__(self, config: InstructBlipVideoConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.vision_model = InstructBlipVideoVisionModel(config.vision_config)
|
||||
self.vision_model = InstructBlipVideoVisionModel._from_config(config.vision_config)
|
||||
|
||||
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
||||
self.qformer = InstructBlipVideoQFormerModel(config.qformer_config)
|
||||
self.qformer = InstructBlipVideoQFormerModel._from_config(config.qformer_config)
|
||||
|
||||
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
|
||||
|
||||
@ -1519,9 +1562,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
if hasattr(self.language_model, "_hf_hook"):
|
||||
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=InstructBlipVideoForConditionalGenerationModelOutput, config_class=InstructBlipVideoVisionConfig
|
||||
output_type=InstructBlipVideoForConditionalGenerationModelOutput, config_class=InstructBlipVideoConfig
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
@ -1538,6 +1582,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
@ -1682,21 +1727,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous().to(logits.device)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)

# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction="mean")

loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
else:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
@ -1708,14 +1747,11 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
**kwargs,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]

if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
return ((loss,) + output) if loss is not None else output

return InstructBlipVideoForConditionalGenerationModelOutput(
loss=loss,
logits=logits,
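The deleted shift-and-flatten cross-entropy is, modulo the extra keywords now forwarded through `**kwargs`, the same computation that `self.loss_function` performs for causal-LM heads. A self-contained sketch of that computation with illustrative names (not the library's internals):

import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Shift so that tokens < n predict n, exactly as the removed code did.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
    # Flatten and average over non-ignored positions (-100 is the conventional padding label).
    return F.cross_entropy(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="mean",
    )

# Dummy shapes: batch=2, seq=5, vocab=11.
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
print(causal_lm_loss(logits, labels, vocab_size=11))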
|
@ -18,7 +18,6 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.models.instructblip.configuration_instructblip import (
|
||||
InstructBlipQFormerConfig,
|
||||
@ -31,11 +30,14 @@ from transformers.models.instructblip.modeling_instructblip import (
|
||||
InstructBlipPreTrainedModel,
|
||||
InstructBlipQFormerModel,
|
||||
InstructBlipVisionModel,
|
||||
KwargsForCausalLM,
|
||||
)
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
from ...utils import add_start_docstrings_to_model_forward, logging
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging, replace_return_docstrings
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
@ -196,6 +198,7 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = None
|
||||
|
||||
|
||||
class InstructBlipVideoModel(InstructBlipModel):
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -211,6 +214,7 @@ class InstructBlipVideoModel(InstructBlipModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@ -273,6 +277,7 @@ class InstructBlipVideoModel(InstructBlipModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
outputs = self.language_model(
|
||||
@ -284,11 +289,9 @@ class InstructBlipVideoModel(InstructBlipModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (vision_outputs, query_outputs, outputs)
|
||||
|
||||
return InstructBlipVideoForConditionalGenerationModelOutput(
|
||||
vision_outputs=vision_outputs,
|
||||
qformer_outputs=query_outputs,
|
||||
@ -297,6 +300,11 @@ class InstructBlipVideoModel(InstructBlipModel):
|
||||
|
||||
|
||||
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=InstructBlipVideoForConditionalGenerationModelOutput, config_class=InstructBlipVideoConfig
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
@ -312,6 +320,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
```python
|
||||
@ -447,21 +456,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
logits = outputs.logits if return_dict else outputs[0]
|
||||
loss = None
|
||||
# we compute the loss here since we need to take into account the sequence length of the query embeds
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
logits = logits[:, -labels.size(1) :, :]
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous().to(logits.device)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(reduction="mean")
|
||||
|
||||
loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
|
||||
else:
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
@ -473,14 +476,11 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
loss = outputs.loss if return_dict else outputs[0]
|
||||
logits = outputs.logits if return_dict else outputs[1]
|
||||
|
||||
if not return_dict:
|
||||
output = (logits, vision_outputs, query_outputs, outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return InstructBlipVideoForConditionalGenerationModelOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
|
@ -35,6 +35,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseMo
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
@ -621,6 +622,7 @@ class InternVLPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
@ -828,6 +830,7 @@ class InternVLModel(InternVLPreTrainedModel):
|
||||
|
||||
return vision_features
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -845,7 +848,7 @@ class InternVLModel(InternVLPreTrainedModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, InternVLModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -904,17 +907,16 @@ class InternVLModel(InternVLPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = InternVLModelOutputWithPast(
|
||||
return InternVLModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()

def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
"""Perform pixel shuffle downsampling on vision features.
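`pixel_shuffle` here is a space-to-depth downsampling of the vision token grid: spatial resolution drops, channel width grows, so the language model sees fewer but wider image tokens. A rough, self-contained sketch of the idea; shapes and the exact reshape/permute order are illustrative rather than copied from InternVL:

import torch

def pixel_shuffle_downsample(features: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
    # features: (batch, width, height, channels) on a square patch grid.
    b, w, h, c = features.shape
    r = int(1 / scale_factor)  # e.g. 2 for scale_factor=0.5
    # Fold r x r spatial neighbourhoods into the channel dimension:
    # (b, w, h, c) -> (b, h/r, w/r, c*r*r), i.e. 4x fewer tokens for r=2.
    features = features.view(b, w, h // r, c * r)
    features = features.permute(0, 2, 1, 3).contiguous()
    return features.view(b, h // r, w // r, c * r * r)

tokens = torch.randn(1, 16, 16, 1024)
print(pixel_shuffle_downsample(tokens).shape)  # torch.Size([1, 8, 8, 4096])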
@ -992,6 +994,9 @@ class InternVLCausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


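`KwargsForCausalLM` simply merges the two groups of optional keywords that may travel through `**kwargs`: attention-backend tuning knobs and loss-related extras. A small sketch of the typing pattern with made-up field names (the real `FlashAttentionKwargs`/`LossKwargs` definitions live in the library and are not reproduced here):

from typing import Optional
from typing_extensions import TypedDict, Unpack

class AttnKwargs(TypedDict, total=False):        # stand-in for FlashAttentionKwargs
    sliding_window: Optional[int]

class LossKwargsSketch(TypedDict, total=False):  # stand-in for LossKwargs
    num_items_in_batch: Optional[int]

class CausalLMKwargs(AttnKwargs, LossKwargsSketch): ...

def forward(**kwargs: Unpack[CausalLMKwargs]) -> None:
    # Type checkers now know which extra keywords are legal here,
    # while at runtime everything is still an ordinary dict.
    print(sorted(kwargs))

forward(sliding_window=4096, num_items_in_batch=8)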
@add_start_docstrings(
|
||||
"""The INTERNVL model which consists of a vision backbone and a language model.""",
|
||||
INTERNVL_START_DOCSTRING,
|
||||
@ -1056,8 +1061,8 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -1138,7 +1143,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
image_sizes=image_sizes,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -1148,7 +1153,9 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return InternVLCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -21,10 +21,10 @@ from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -32,10 +32,13 @@ from ...modeling_outputs import (
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
)
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
@ -466,7 +469,7 @@ class Kosmos2VisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> Kosmos2 doesn't cast attn weights to fp32
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
@ -481,7 +484,7 @@ def eager_attention_forward(
if attention_mask is not None:
attn_weights = attn_weights + attention_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

attn_output = torch.matmul(attn_weights, value)
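The only functional change in this hunk is dropping the float32 upcast before the softmax, so attention probabilities stay in the model's compute dtype, which is what the pre-existing Kosmos-2 eager path and the non-eager backends already did. A tiny check illustrating the difference being removed; the numbers are illustrative only:

import torch

scores = torch.randn(2, 8, 16, 16, dtype=torch.bfloat16)

# Old path: upcast to fp32 for the softmax, then cast back.
probs_fp32 = torch.softmax(scores.float(), dim=-1).to(scores.dtype)
# New path: softmax directly in the compute dtype.
probs_native = torch.softmax(scores, dim=-1)

print(probs_fp32.dtype, probs_native.dtype)     # both bfloat16 after the cast back
print((probs_fp32 - probs_native).abs().max())  # rounding-level difference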
@ -892,6 +895,7 @@ class KosmosTextAttention(nn.Module):
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
@ -929,6 +933,7 @@ class KosmosTextAttention(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
@ -953,8 +958,7 @@ class KosmosTextAttention(nn.Module):
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
query_states = self._shape(self.q_proj(hidden_states) * self.scaling)
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2))
|
||||
query_states = self._shape(self.q_proj(hidden_states))
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
@ -966,32 +970,33 @@ class KosmosTextAttention(nn.Module):
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

src_len = key_states.size(2)
attention_interface: Callable = eager_attention_forward

if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, seq_length, src_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, seq_length, src_len)}, but is {attention_mask.size()}"
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_weights = attn_weights + attention_mask
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_weights = nn.functional.softmax(attn_weights, dim=-1)

# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask

attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

# attn_output = torch.bmm(attn_probs, value_states) ?
context_states = torch.matmul(attn_weights, value_states)
# attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) ?
context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scaling,
**kwargs,
)

attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
if self.inner_attn_ln is not None:
context_states = self.inner_attn_ln(context_states)
attn_output = self.inner_attn_ln(attn_output)

attn_output = self.out_proj(context_states)
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights, past_key_value

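The rewritten block above follows the repo-wide dispatch pattern: default to the eager implementation, swap in whatever callable is registered for the configured backend, and fall back to eager when SDPA is asked to return attention weights. A condensed, self-contained sketch of that control flow; the registry below is a placeholder, whereas the real `ALL_ATTENTION_FUNCTIONS` is populated by the library:

from typing import Callable, Dict
import torch
import torch.nn.functional as F

def eager_attention(q, k, v, mask, scaling, dropout=0.0):
    w = torch.matmul(q, k.transpose(-1, -2)) * scaling
    if mask is not None:
        w = w + mask
    w = torch.softmax(w, dim=-1)
    return torch.matmul(w, v), w

def sdpa_attention(q, k, v, mask, scaling, dropout=0.0):
    # Requires a recent PyTorch for the `scale` argument.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout, scale=scaling)
    return out, None  # SDPA cannot return attention weights

REGISTRY: Dict[str, Callable] = {"sdpa": sdpa_attention}  # placeholder registry

def run_attention(impl, output_attentions, q, k, v, mask, scaling):
    attention_interface: Callable = eager_attention
    if impl != "eager":
        if impl == "sdpa" and output_attentions:
            print("sdpa cannot return attention weights; falling back to eager")
        else:
            attention_interface = REGISTRY[impl]
    return attention_interface(q, k, v, mask, scaling)

q = k = v = torch.randn(1, 4, 10, 32)
out, weights = run_attention("sdpa", True, q, k, v, None, scaling=32**-0.5)
print(out.shape, weights is None)  # falls back to eager, so weights are returned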
@ -1060,6 +1065,7 @@ class Kosmos2TextBlock(nn.Module):
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states
|
||||
|
||||
@ -1076,6 +1082,7 @@ class Kosmos2TextBlock(nn.Module):
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -1103,6 +1110,7 @@ class Kosmos2TextBlock(nn.Module):
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -1216,6 +1224,7 @@ class Kosmos2TextTransformer(nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
@ -1233,6 +1242,7 @@ class Kosmos2TextTransformer(nn.Module):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1338,6 +1348,7 @@ class Kosmos2TextTransformer(nn.Module):
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -1357,18 +1368,6 @@ class Kosmos2TextTransformer(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
present_key_value_states,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=present_key_value_states,
|
||||
@ -1387,6 +1386,9 @@ class Kosmos2PreTrainedModel(PreTrainedModel):
config_class = Kosmos2Config
supports_gradient_checkpointing = True
_no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"]
_supports_attention_backend = True
_supports_flash_attn_2 = True
_supports_sdpa = True

def _init_weights(self, module):
"""Initialize the weights"""
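With `_supports_attention_backend`, `_supports_flash_attn_2` and `_supports_sdpa` now advertised, the backend can be chosen at load time like for any other model. A usage sketch; the checkpoint id is the public Kosmos-2 one, and the flash-attention variant additionally assumes `flash-attn` is installed:

import torch
from transformers import AutoModelForVision2Seq

model_id = "microsoft/kosmos-2-patch14-224"

# Default resolution picks the best available supported backend.
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16)

# Or pin one explicitly: "eager", "sdpa", or "flash_attention_2".
model_sdpa = AutoModelForVision2Seq.from_pretrained(
    model_id, torch_dtype=torch.float16, attn_implementation="sdpa"
)
print(model_sdpa.config._attn_implementation)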
@ -1525,6 +1527,7 @@ class Kosmos2TextModel(Kosmos2PreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=Kosmos2TextConfig)
|
||||
def forward(
|
||||
@ -1544,6 +1547,7 @@ class Kosmos2TextModel(Kosmos2PreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
r"""
|
||||
Returns:
|
||||
@ -1565,9 +1569,13 @@ class Kosmos2TextModel(Kosmos2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
|
||||
@ -1600,6 +1608,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=Kosmos2TextConfig)
|
||||
def forward(
|
||||
@ -1620,6 +1629,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1652,27 +1662,14 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
lm_logits = self.lm_head(outputs[0])

loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)

if not return_dict:
output = (lm_logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=lm_logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

return CausalLMOutputWithCrossAttentions(
loss=loss,
@ -1804,6 +1801,7 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.text_model.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Kosmos2ModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -1822,6 +1820,7 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, Kosmos2ModelOutput]:
|
||||
r"""
|
||||
Returns:
|
||||
@ -1893,13 +1892,10 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
outputs = outputs + (image_embeds, projection_attentions, vision_model_output)
|
||||
return tuple(output for output in outputs if output is not None)
|
||||
|
||||
return Kosmos2ModelOutput(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
@ -1946,6 +1942,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.text_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Kosmos2ForConditionalGenerationModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -1964,6 +1961,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Kosmos2ForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -2048,13 +2046,10 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
outputs = lm_outputs + (image_embeds, projection_attentions, vision_model_output)
|
||||
return tuple(output for output in outputs if output is not None)
|
||||
|
||||
return Kosmos2ForConditionalGenerationModelOutput(
|
||||
loss=lm_outputs.loss,
|
||||
logits=lm_outputs.logits,
|
||||
|
@ -39,8 +39,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@ -244,6 +246,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
@ -256,12 +259,13 @@ def eager_attention_forward(
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
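Two details in this hunk: the hard-coded `1/sqrt(head_dim)` becomes the caller-supplied `scaling` (typically the same value, but now shared with the non-eager backends), and the scores are no longer upcast to float32 before the softmax. A quick numeric check of the scaling equivalence; shapes are illustrative:

import math
import torch

head_dim = 128
scaling = head_dim**-0.5

q = torch.randn(1, 8, 16, head_dim)
k = torch.randn(1, 8, 16, head_dim)

old = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
new = torch.matmul(q, k.transpose(2, 3)) * scaling
print(torch.allclose(old, new))  # True: same scores, just parameterised differently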
@ -607,6 +611,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -712,13 +717,12 @@ class Llama4TextModel(Llama4PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)

output = BaseModelOutputWithPast(
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()

@torch.compiler.disable(recursive=False) # the operations in this method are not compilable
def _update_causal_mask(
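The trailing `return output if return_dict else output.to_tuple()` can be dropped because `@can_return_tuple` takes over the conversion: forward always builds the `ModelOutput`, and the decorator turns it into a tuple when one was requested. A simplified sketch of how such a decorator can work; this is an approximation, not the library's implementation:

import functools
from dataclasses import dataclass

def can_return_tuple_sketch(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        use_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
        output = forward(self, *args, **kwargs)  # forward always returns a ModelOutput-like object
        return output if use_dict else output.to_tuple()
    return wrapper

@dataclass
class TinyOutput:
    last_hidden_state: int
    def to_tuple(self):
        return (self.last_hidden_state,)

class TinyModel:
    class config:
        use_return_dict = True

    @can_return_tuple_sketch
    def forward(self):
        return TinyOutput(last_hidden_state=42)

m = TinyModel()
print(m.forward())                   # TinyOutput(last_hidden_state=42)
print(m.forward(return_dict=False))  # (42,)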
@ -931,6 +935,9 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
|
||||
_no_split_modules = ["Llama4TextDecoderLayer"]
|
||||
base_model_prefix = "language_model"
|
||||
@ -965,6 +972,7 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -981,7 +989,7 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -1031,7 +1039,7 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
@ -1044,10 +1052,6 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1208,6 +1212,7 @@ class Llama4VisionAttention(nn.Module):
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
self.num_key_value_groups = 1
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True)
|
||||
@ -1593,7 +1598,6 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
|
||||
_tp_plan = {}
|
||||
base_model_prefix = ""
|
||||
config_class = Llama4Config
|
||||
_supports_flex_attn = True
|
||||
|
||||
def __init__(self, config: Llama4Config):
|
||||
super().__init__(config)
|
||||
@ -1673,7 +1677,7 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1780,7 +1784,7 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
|
@ -23,9 +23,12 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -171,6 +174,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Llava isn't meant for training from scratch - only
|
||||
@ -331,6 +335,7 @@ class LlavaModel(LlavaPreTrainedModel):
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -348,7 +353,7 @@ class LlavaModel(LlavaPreTrainedModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, LlavaModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -407,17 +412,19 @@ class LlavaModel(LlavaPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = LlavaModelOutputWithPast(
|
||||
return LlavaModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -484,8 +491,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -553,7 +560,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
image_sizes=image_sizes,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -563,7 +570,9 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)

return LlavaCausalLMOutputWithPast(
loss=loss,
|
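Forwarding `**kwargs` into `self.loss_function` is not cosmetic: it lets callers pass extras such as `num_items_in_batch` (used by the Trainer's gradient-accumulation handling) so the loss is normalised over the true label-token count rather than per micro-batch. A small numeric illustration of why that distinction matters; the numbers are made up:

import torch

losses = torch.tensor([4.0, 2.0, 6.0])  # summed token losses of 3 micro-batches
tokens = torch.tensor([4.0, 1.0, 5.0])  # label-token counts per micro-batch

per_micro_mean = (losses / tokens).mean()  # 1.4: mean of per-micro-batch means
global_mean = losses.sum() / tokens.sum()  # 1.2: what num_items_in_batch enables
print(per_micro_mean.item(), global_mean.item())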
@ -26,9 +26,12 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -280,6 +283,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
@ -528,6 +532,7 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -545,7 +550,7 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, LlavaNextModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -609,17 +614,19 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = LlavaNextModelOutputWithPast(
|
||||
return LlavaNextModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -688,7 +695,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -756,7 +763,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -766,7 +773,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return LlavaNextCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -30,9 +30,12 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -223,6 +226,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
@ -581,6 +585,7 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -599,7 +604,7 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -684,10 +689,10 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = LlavaNextVideoModelOutputWithPast(
|
||||
return LlavaNextVideoModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
@ -695,7 +700,6 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
video_hidden_states=video_features if pixel_values_videos is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
def get_video_features(
|
||||
self,
|
||||
@ -744,6 +748,9 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
return video_features
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
|
||||
LLAVA_NEXT_VIDEO_START_DOCSTRING,
|
||||
@ -811,7 +818,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
|
||||
r"""
|
||||
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, image_size, image_size)):
|
||||
@ -915,10 +922,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
image_sizes=image_sizes,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -928,7 +935,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return LlavaNextVideoCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -22,6 +22,7 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.models.llava_next.modeling_llava_next import (
|
||||
KwargsForCausalLM,
|
||||
LlavaNextCausalLMOutputWithPast,
|
||||
LlavaNextForConditionalGeneration,
|
||||
LlavaNextModel,
|
||||
@ -31,6 +32,8 @@ from transformers.models.llava_next.modeling_llava_next import (
|
||||
)
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -378,7 +381,7 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -463,10 +466,10 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = LlavaNextVideoModelOutputWithPast(
|
||||
return LlavaNextVideoModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
@ -474,7 +477,6 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
video_hidden_states=video_features if pixel_values_videos is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
|
||||
@ -580,7 +582,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
|
||||
r"""
|
||||
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, image_size, image_size)):
|
||||
@ -684,10 +686,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
image_sizes=image_sizes,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -697,7 +699,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return LlavaNextVideoCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -30,9 +30,12 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
can_return_tuple,
|
||||
is_torchdynamo_compiling,
|
||||
@ -405,6 +408,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
@ -570,6 +574,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -590,7 +595,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -681,10 +686,10 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = LlavaOnevisionModelOutputWithPast(
|
||||
return LlavaOnevisionModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
@ -693,8 +698,6 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
video_hidden_states=video_features if pixel_values_videos is not None else None,
|
||||
)
|
||||
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
def get_video_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
@ -756,6 +759,9 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
return image_features
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
|
||||
LLAVA_ONEVISION_START_DOCSTRING,
|
||||
@ -824,7 +830,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -909,7 +915,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -919,7 +925,9 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return LlavaOnevisionCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -21,6 +21,7 @@ from torch import nn
|
||||
|
||||
from transformers.models.llava_next.image_processing_llava_next_fast import LlavaNextImageProcessorFast
|
||||
from transformers.models.llava_next_video.modeling_llava_next_video import (
|
||||
KwargsForCausalLM,
|
||||
LlavaNextVideoCausalLMOutputWithPast,
|
||||
LlavaNextVideoForConditionalGeneration,
|
||||
LlavaNextVideoModel,
|
||||
@ -36,6 +37,8 @@ from ...image_utils import (
|
||||
OPENAI_CLIP_STD,
|
||||
PILImageResampling,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import add_start_docstrings, can_return_tuple, is_torchdynamo_compiling, logging
|
||||
|
||||
|
||||
@ -217,6 +220,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
|
||||
return video_features
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -237,7 +241,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -328,10 +332,10 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = LlavaOnevisionModelOutputWithPast(
|
||||
return LlavaOnevisionModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
@ -340,8 +344,6 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
video_hidden_states=video_features if pixel_values_videos is not None else None,
|
||||
)
|
||||
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGeneration):
|
||||
@can_return_tuple
|
||||
@ -367,7 +369,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -452,7 +454,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -462,7 +464,9 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return LlavaOnevisionCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -28,9 +28,12 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -233,6 +236,7 @@ class Mistral3PreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Mistral3 isn't meant for training from scratch - only
|
||||
@ -251,6 +255,144 @@ class Mistral3PreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.""",
|
||||
MISTRAL3_START_DOCSTRING,
|
||||
)
|
||||
class Mistral3Model(Mistral3PreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
|
||||
def __init__(self, config: Mistral3Config):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
|
||||
self.multi_modal_projector = Mistral3MultiModalProjector(config)
|
||||
self.language_model = AutoModel.from_config(config.text_config)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
image_sizes: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
image_sizes (`torch.Tensor`):
|
||||
Tensor containing the image sizes as returned by the processor.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
||||
# If we have one vision feature layer, return the corresponding hidden states,
|
||||
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||
if isinstance(vision_feature_layer, int):
|
||||
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||
else:
|
||||
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, Mistral3ModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return Mistral3ModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
MISTRAL3_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
@ -328,142 +470,6 @@ MISTRAL3_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.""",
|
||||
MISTRAL3_START_DOCSTRING,
|
||||
)
|
||||
class Mistral3Model(Mistral3PreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
|
||||
def __init__(self, config: Mistral3Config):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
|
||||
self.multi_modal_projector = Mistral3MultiModalProjector(config)
|
||||
self.language_model = AutoModel.from_config(config.text_config)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
image_sizes: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
image_sizes (`torch.Tensor`):
|
||||
Tensor containing the image sizes as returned by the processor.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
||||
# If we have one vision feature layer, return the corresponding hidden states,
|
||||
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||
if isinstance(vision_feature_layer, int):
|
||||
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||
else:
|
||||
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
|
||||
return image_features
|
||||
|
||||
@add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Mistral3ModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
output = Mistral3ModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The MISTRAL3 model which consists of a vision backbone and a language model.""",
|
||||
MISTRAL3_START_DOCSTRING,
|
||||
@ -526,8 +532,8 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -585,7 +591,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
image_sizes=image_sizes,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -595,7 +601,9 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return Mistral3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -19,6 +19,8 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -27,6 +29,7 @@ from ...utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ..llava.modeling_llava import (
|
||||
KwargsForCausalLM,
|
||||
LlavaCausalLMOutputWithPast,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaModel,
|
||||
@ -174,6 +177,7 @@ class Mistral3Model(LlavaModel):
|
||||
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
@ -189,7 +193,7 @@ class Mistral3Model(LlavaModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, Mistral3ModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -239,17 +243,16 @@ class Mistral3Model(LlavaModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = Mistral3ModelOutputWithPast(
|
||||
return Mistral3ModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
@ -271,8 +274,8 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**lm_kwargs,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -330,7 +333,7 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
image_sizes=image_sizes,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -340,7 +343,9 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return Mistral3CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -15,21 +15,24 @@
|
||||
"""PyTorch Mllama model."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ... import PreTrainedModel
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -180,13 +183,56 @@ class MllamaVisionMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class MllamaVisionAttention(nn.Module):
|
||||
def __init__(self, config: MllamaVisionConfig):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.attention_heads
|
||||
self.head_dim = config.hidden_size // config.attention_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.num_key_value_groups = 1
|
||||
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
|
||||
@ -198,6 +244,7 @@ class MllamaVisionAttention(nn.Module):
|
||||
hidden_state: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
query = self.q_proj(hidden_state)
|
||||
key = self.k_proj(hidden_state)
|
||||
@ -210,73 +257,35 @@ class MllamaVisionAttention(nn.Module):
|
||||
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attention_mask,
|
||||
dropout=0.0,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
|
||||
|
||||
output = self.o_proj(attn_output)
|
||||
attn_output = attn_output.reshape(batch_size, q_seq_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return output, attn_weights
|
||||
|
||||
|
||||
class MllamaVisionSdpaAttention(MllamaVisionAttention):
|
||||
# Adapted from MllamaVisionAttention
|
||||
def forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
if output_attentions:
|
||||
logger.warning_once(
|
||||
"MllamaModel is using MllamaVisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_state=hidden_state,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
query = self.q_proj(hidden_state)
|
||||
key = self.k_proj(hidden_state)
|
||||
value = self.v_proj(hidden_state)
|
||||
|
||||
batch_size, q_seq_len, _ = query.shape
|
||||
_, kv_seq_len, _ = key.shape
|
||||
|
||||
query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
|
||||
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
|
||||
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
|
||||
|
||||
output = self.o_proj(attn_output)
|
||||
|
||||
return output, None
|
||||
|
||||
|
||||
MLLAMA_VISION_ATTENTION_CLASSES = {"eager": MllamaVisionAttention, "sdpa": MllamaVisionSdpaAttention}
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class MllamaVisionEncoderLayer(nn.Module):
|
||||
@ -288,7 +297,7 @@ class MllamaVisionEncoderLayer(nn.Module):
|
||||
self.is_gated = is_gated
|
||||
self.intermediate_size = config.intermediate_size
|
||||
|
||||
self.self_attn = MLLAMA_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.self_attn = MllamaVisionAttention(config)
|
||||
self.mlp = MllamaVisionMLP(config)
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
|
||||
@ -453,6 +462,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
self.head_dim = config.hidden_size // self.num_heads
|
||||
self.layer_idx = layer_idx
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
@ -471,6 +481,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
@ -503,17 +514,29 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
|
||||
)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
@ -522,100 +545,6 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class MllamaTextCrossSdpaAttention(MllamaTextCrossAttention):
|
||||
"""
|
||||
Mllama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`MllamaTextCrossAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from MllamaTextCrossAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"MllamaModel is using MllamaTextCrossSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
cross_attention_states=cross_attention_states,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states)
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
query_states = self.q_norm(query_states)
|
||||
|
||||
if cross_attention_states is not None:
|
||||
key_states = self.k_proj(cross_attention_states)
|
||||
value_states = self.v_proj(cross_attention_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
# if we have a new image + new tokens, we only computed key_states on that new image
|
||||
# we still update the cross key states, past_image, new_image. And use it!
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
elif cache_position[0] != 0:
|
||||
key_states, value_states = (
|
||||
past_key_value.key_cache[self.layer_idx],
|
||||
past_key_value.value_cache[self.layer_idx],
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
|
||||
)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
key_states = self.k_norm(key_states)
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if attention_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
@ -652,19 +581,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
class MllamaTextSelfAttention(nn.Module):
|
||||
def __init__(self, config: MllamaTextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
@ -675,6 +591,7 @@ class MllamaTextSelfAttention(nn.Module):
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.head_dim = config.hidden_size // self.num_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = config.rope_theta
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
@ -712,23 +629,29 @@ class MllamaTextSelfAttention(nn.Module):
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
@ -737,92 +660,6 @@ class MllamaTextSelfAttention(nn.Module):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class MllamaTextSelfSdpaAttention(MllamaTextSelfAttention):
|
||||
# Adapted from MllamaTextSelfAttention
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
position_embeddings: torch.Tensor,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
past_key_value=None,
|
||||
cache_position=None,
|
||||
**kwargs,
|
||||
):
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"MllamaModel is using MllamaTextSelfSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES = {"eager": MllamaTextCrossAttention, "sdpa": MllamaTextCrossSdpaAttention}
|
||||
MLLAMA_TEXT_ATTENTION_CLASSES = {"eager": MllamaTextSelfAttention, "sdpa": MllamaTextSelfSdpaAttention}
|
||||
|
||||
|
||||
# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
|
||||
class MllamaTextMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
@ -847,7 +684,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = MLLAMA_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
self.self_attn = MllamaTextSelfAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = MllamaTextMLP(config)
|
||||
self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -868,6 +705,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -905,6 +743,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -931,7 +770,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.cross_attn = MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
|
||||
self.cross_attn = MllamaTextCrossAttention(config, layer_idx=layer_idx)
|
||||
|
||||
self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))
|
||||
@ -953,6 +792,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -964,6 +804,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
|
||||
|
||||
@ -1026,7 +867,9 @@ class MllamaPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = False # static cache cannot have different shapes for each layer
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
@ -1694,6 +1537,7 @@ class MllamaTextModel(MllamaPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
"""
|
||||
|
||||
@ -1804,6 +1648,7 @@ class MllamaTextModel(MllamaPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -1832,6 +1677,9 @@ class MllamaTextModel(MllamaPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The Mllama Text Model with a language modeling head on top.""",
|
||||
MLLAMA_START_DOCSTRING,
|
||||
@ -1888,7 +1736,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**loss_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1945,6 +1793,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -1953,7 +1802,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
@ -1999,6 +1848,7 @@ class MllamaModel(MllamaPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -2017,6 +1867,7 @@ class MllamaModel(MllamaPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -2079,15 +1930,15 @@ class MllamaModel(MllamaPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = BaseModelOutputWithPast(
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -2153,7 +2004,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**loss_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -2220,6 +2071,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -2229,7 +2081,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **loss_kwargs)
|
||||
loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **kwargs)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -236,7 +236,7 @@ class MoshiConfig(PretrainedConfig):
|
||||
|
||||
model_type = "moshi"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
sub_configs = {"audio_encoder_config": AutoConfig}
|
||||
sub_configs = {"audio_encoder_config": AutoConfig, "depth_decoder_config": MoshiDepthConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -1907,7 +1907,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
|
||||
self.audio_encoder = AutoModel.from_config(config.audio_encoder_config)
|
||||
self.decoder = MoshiForCausalLM(config)
|
||||
|
||||
self.depth_decoder = MoshiDepthDecoder(config.depth_decoder_config)
|
||||
self.depth_decoder = MoshiDepthDecoder._from_config(config.depth_decoder_config)
|
||||
|
||||
self.num_codebooks = config.num_codebooks
|
||||
self.post_init()
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch OPT model."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -27,18 +27,21 @@ from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import (
|
||||
AttentionMaskConverter,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@ -53,7 +56,7 @@ if is_torch_flex_attn_available():
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -98,6 +101,30 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
|
||||
return super().forward(position_ids + self.offset)
|
||||
|
||||
|
||||
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class OPTAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@ -143,9 +170,8 @@ class OPTAttention(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
# isn't needed in normal attention, but needed in flash attention so to keep the signature same
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
@ -165,206 +191,35 @@ class OPTAttention(nn.Module):
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(3, 2))
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
|
||||
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}"
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and output_attentions:
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_probs, value_states)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned aross GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_probs, past_key_value
|
||||
|
||||
|
||||
class OptFlashAttention2(OPTAttention):
|
||||
"""
|
||||
OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
|
||||
The only required change would be on the forward pass where it needs to correctly call the public API of flash
|
||||
attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, query_length, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
|
||||
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
attn_dropout = self.dropout if self.training else 0.0
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
# In PEFT, we usually cast the layer norms to float32 for training stability reasons;
|
||||
# therefore the input hidden states get silently cast to float32. Hence, we need to
|
||||
# cast them back to float16 just to be sure everything works as expected.
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
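As a standalone illustration of the dtype-selection cascade above (autocast dtype, then a recorded pre-quantization dtype if present, then the projection weight dtype), here is a hedged sketch; `module` and `_pre_quantization_dtype` mirror the names used in the diff and are assumptions outside of it.

import torch

def pick_flash_attn_dtype(module) -> torch.dtype:
    # Autocast dtype first, then the dtype recorded before quantization,
    # finally the dtype of the q_proj weights.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if hasattr(module.config, "_pre_quantization_dtype"):
        return module.config._pre_quantization_dtype
    return module.q_proj.weight.dtype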
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
query_length,
|
||||
position_ids=position_ids,
|
||||
dropout=attn_dropout,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
|
||||
attn_output = self.out_proj(attn_weights_reshaped)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
class OPTSdpaAttention(OPTAttention):
|
||||
"""
|
||||
OPT sdpa attention module. This module inherits from `OPTAttention` as the weights of the module stay untouched.
|
||||
The only required change would be on the forward pass where it needs to correctly call the public API of sdpa
|
||||
attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
logger.warning_once(
|
||||
"OPTModel is using SDPA attention, which currently does not support output_attentions=True."
|
||||
'Falling back to eager attention. This warning can be removed by using attn_implementation="eager".'
|
||||
)
|
||||
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
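As a minimal, self-contained illustration of the call above (arbitrary shapes, not tied to OPT), note that `is_causal` is passed as a plain Python bool so torch.compile can specialize on it:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# No explicit mask and seq_len > 1, so the causal path is requested via the flag.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])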
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
|
||||
OPT_ATTENTION_CLASSES = {
|
||||
"eager": OPTAttention,
|
||||
"flash_attention_2": OptFlashAttention2,
|
||||
"sdpa": OPTSdpaAttention,
|
||||
}
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class OPTDecoderLayer(nn.Module):
|
||||
@ -372,7 +227,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
self.self_attn = OPTAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
self.dropout = config.dropout
|
||||
@ -395,6 +250,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -429,6 +285,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
@ -495,8 +352,10 @@ class OPTPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["OPTDecoderLayer"]
|
||||
_supports_attention_backend = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
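With `_supports_attention_backend` set alongside the flash/SDPA/flex flags above, the backend can be chosen at load time through `attn_implementation`. A usage sketch, where the checkpoint name is only illustrative:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",          # illustrative checkpoint
    torch_dtype=torch.float16,
    attn_implementation="sdpa",   # or "eager", "flash_attention_2", "flex_attention"
)

Requesting `output_attentions=True` still falls back to eager, as the warning earlier in this diff notes.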
|
||||
@ -763,6 +622,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
|
||||
return causal_mask
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@ -776,6 +636,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -942,6 +803,7 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -966,8 +828,6 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
if return_legacy_cache:
|
||||
next_cache = next_cache.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
@ -996,6 +856,7 @@ class OPTModel(OPTPreTrainedModel):
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
@ -1016,6 +877,7 @@ class OPTModel(OPTPreTrainedModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1035,13 +897,11 @@ class OPTModel(OPTPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return decoder_outputs
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
@ -1050,6 +910,9 @@ class OPTModel(OPTPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
@ -1081,6 +944,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
@can_return_tuple
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
@ -1096,7 +960,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -1198,8 +1062,9 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
@ -1215,10 +1080,6 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
|
@ -23,9 +23,12 @@ from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@ -159,6 +162,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
_supports_static_cache = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of PaliGemma isn't meant for training from scratch - only
|
||||
@ -352,6 +356,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
image_features = image_features / (self.config.text_config.hidden_size**0.5)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -368,7 +373,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, PaligemmaModelOutputWithPast]:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
@ -436,17 +441,19 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = PaligemmaModelOutputWithPast(
|
||||
return PaligemmaModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -512,7 +519,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -570,7 +577,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -580,7 +587,9 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return PaliGemmaCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -21,14 +21,18 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ... import PreTrainedModel
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutput
|
||||
from ...modeling_rope_utils import dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
)
|
||||
from .configuration_pixtral import PixtralVisionConfig
|
||||
|
||||
|
||||
@ -132,7 +136,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
# Copied from transformers.models.smolvlm.modeling_smolvlm.eager_attention_forward
|
||||
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
@ -167,10 +171,11 @@ class PixtralAttention(nn.Module):
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.is_causal = False
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.is_causal = False
|
||||
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
@ -211,28 +216,22 @@ class PixtralAttention(nn.Module):
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
# Since we use packing, if Flash-Attn 2 is selected we rely on position_ids
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True)
|
||||
attention_mask = None
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
is_causal=self.is_causal,
|
||||
output_attentions=output_attentions,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
@ -288,7 +287,7 @@ class PixtralAttentionLayer(nn.Module):
|
||||
attention_mask: torch.Tensor,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
"""
|
||||
Args:
|
||||
@ -341,7 +340,7 @@ class PixtralTransformer(nn.Module):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
r"""
|
||||
Args:
|
||||
@ -383,7 +382,6 @@ class PixtralTransformer(nn.Module):
|
||||
attention_mask,
|
||||
position_embeddings,
|
||||
output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
@ -431,6 +429,10 @@ class PixtralPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_attention_backend = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_no_split_modules = ["PixtralAttentionLayer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
@ -508,6 +510,7 @@ class PixtralVisionModel(PixtralPreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.patch_conv
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -517,7 +520,7 @@ class PixtralVisionModel(PixtralPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
"""
|
||||
Returns:
|
||||
@ -551,17 +554,15 @@ class PixtralVisionModel(PixtralPreTrainedModel):
|
||||
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
|
||||
)
|
||||
|
||||
out = self.transformer(
|
||||
return self.transformer(
|
||||
patch_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
__all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"]
|
||||
|
@ -24,17 +24,20 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -78,6 +81,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
@ -574,80 +578,6 @@ class SmolVLMConnector(nn.Module):
|
||||
return image_hidden_states
|
||||
|
||||
|
||||
SMOLVLM_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||
information on the default strategy.
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
||||
The tensors corresponding to the input images. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
|
||||
[`CLIPImageProcessor`] for processing images).
|
||||
pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
|
||||
Mask to avoid performing attention on padding pixel indices.
|
||||
image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The hidden states of the image encoder after modality projection.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""SmolVLM model consisting of a SIGLIP vision encoder and Llama3 language decoder""",
|
||||
SMOLVLM_START_DOCSTRING,
|
||||
@ -746,18 +676,7 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
|
||||
merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
|
||||
return merged_embeds
|
||||
|
||||
@add_start_docstrings_to_model_forward(
|
||||
"""
|
||||
Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
|
||||
the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
|
||||
max_num_images is the maximum number of images among the batch_size samples in the batch.
|
||||
Padding images are not needed beyond padding the pixel_values at the entrance of the model.
|
||||
For efficiency, only the real images are passed through the vision_model's forward, by
|
||||
discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
|
||||
image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
|
||||
""",
|
||||
SMOLVLM_INPUTS_DOCSTRING,
|
||||
)
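A worked example of the bookkeeping described in the docstring above (pure arithmetic, no model involved):

num_images_per_sample = [1, 3, 1, 2]
image_batch_size = sum(num_images_per_sample)  # 7 real images reach the vision model
max_num_images = max(num_images_per_sample)    # 3 -> second dim of the padded pixel_values
# padded pixel_values shape: (batch_size=4, max_num_images=3, 3, max_height, max_width)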
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@ -773,6 +692,7 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, SmolVLMBaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -873,13 +793,11 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
|
||||
|
||||
return SmolVLMBaseModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
@ -927,6 +845,83 @@ class SmolVLMCausalLMOutputWithPast(ModelOutput):
|
||||
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
SMOLVLM_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||
information on the default strategy.
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
||||
The tensors corresponding to the input images. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
|
||||
[`CLIPImageProcessor`] for processing images).
|
||||
pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
|
||||
Mask to avoid performing attention on padding pixel indices.
|
||||
image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The hidden states of the image encoder after modality projection.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The SmolVLM Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
|
||||
SMOLVLM_START_DOCSTRING,
|
||||
@ -979,6 +974,7 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(SMOLVLM_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=SmolVLMCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -998,6 +994,7 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, SmolVLMCausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@ -1066,7 +1063,8 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -1076,26 +1074,9 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
if attention_mask is not None:
|
||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
|
||||
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
|
||||
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
|
||||
else:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return SmolVLMCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -20,7 +20,10 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import DynamicCache
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
can_return_tuple,
|
||||
logging,
|
||||
)
|
||||
from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
|
||||
@ -195,6 +198,7 @@ class SmolVLMModel(Idefics3Model):
|
||||
merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
|
||||
return merged_embeds
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@ -210,6 +214,7 @@ class SmolVLMModel(Idefics3Model):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, SmolVLMBaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -310,13 +315,11 @@ class SmolVLMModel(Idefics3Model):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
|
||||
|
||||
return SmolVLMBaseModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
|
@ -23,9 +23,12 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -181,6 +184,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
@ -387,6 +391,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
|
||||
return video_features, num_frames
|
||||
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
@ -404,7 +409,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, VideoLlavaModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -475,10 +480,10 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = VideoLlavaModelOutputWithPast(
|
||||
return VideoLlavaModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
@ -486,7 +491,9 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
image_hidden_states=image_features if pixel_values_images is not None else None,
|
||||
video_hidden_states=video_features if pixel_values_videos is not None else None,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -559,7 +566,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -671,7 +678,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**lm_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -681,7 +688,9 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return VideoLlavaCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
|
@ -171,6 +171,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of VipLlava isn't meant for training from scratch - only
|
||||
|
@ -461,6 +461,7 @@ class Blip2ForConditionalGenerationDecoderOnlyModelTester:
|
||||
@require_torch
|
||||
class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
|
||||
additional_model_inputs = ["input_ids"]
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
@ -526,15 +527,11 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
|
||||
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
|
||||
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
|
||||
|
||||
# `None` is the requested implementation and will be assigned to each sub-config
|
||||
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
|
||||
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
|
||||
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
|
||||
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
|
||||
self.assertTrue(model.language_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model.qformer.config._attn_implementation == "eager")
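The assertions above reflect that a composite model can resolve a different implementation per sub-model. If finer control is wanted, recent transformers versions also accept a dict keyed by sub-config name; treat the exact keys and the dict form below as an assumption to verify against the installed version.

from transformers import Blip2ForConditionalGeneration

# Assumed dict form: keys are sub-config names; checkpoint is illustrative.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    attn_implementation={"text_config": "eager", "vision_config": "sdpa", "qformer_config": "eager"},
)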
|
||||
|
||||
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
@ -545,20 +542,13 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
if (
|
||||
class_name.endswith("Attention")
|
||||
and getattr(submodule, "config", None)
|
||||
and submodule.config._attn_implementation == "sdpa"
|
||||
):
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa and any(
|
||||
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
|
||||
):
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@ -869,6 +859,7 @@ class Blip2ModelTester:
|
||||
@require_torch
|
||||
class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
|
||||
additional_model_inputs = ["input_ids", "decoder_input_ids"]
|
||||
# Doesn't run generation tests. TODO: fix generation tests for Blip2ForConditionalGeneration
|
||||
all_generative_model_classes = ()
|
||||
pipeline_model_mapping = (
|
||||
@ -967,15 +958,11 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
|
||||
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
|
||||
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
|
||||
|
||||
# `None` is the requested implementation and will be assigned to each sub-config
|
||||
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
|
||||
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
|
||||
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
|
||||
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
|
||||
self.assertTrue(model.language_model.config._attn_implementation == "eager")
|
||||
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model.qformer.config._attn_implementation == "eager")
|
||||
|
||||
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
@ -986,20 +973,13 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
if (
|
||||
class_name.endswith("Attention")
|
||||
and getattr(submodule, "config", None)
|
||||
and submodule.config._attn_implementation == "sdpa"
|
||||
):
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa and any(
|
||||
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
|
||||
):
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@ -1485,6 +1465,7 @@ class Blip2TextRetrievalModelTester:
|
||||
@require_torch
|
||||
class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else ()
|
||||
additional_model_inputs = ["input_ids"]
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
|
@ -475,6 +475,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = {"image-text-to-text": InstructBlipForConditionalGeneration}
|
||||
additional_model_inputs = ["qformer_input_ids", "input_ids"]
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
@ -687,15 +688,11 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
|
||||
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
|
||||
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
|
||||
|
||||
# `None` is the requested implementation and will be assigned to each sub-config
|
||||
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
|
||||
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
|
||||
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
|
||||
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
|
||||
self.assertTrue(model.language_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model.qformer.config._attn_implementation == "eager")
|
||||
|
||||
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
@ -706,20 +703,13 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
if (
|
||||
class_name.endswith("Attention")
|
||||
and getattr(submodule, "config", None)
|
||||
and submodule.config._attn_implementation == "sdpa"
|
||||
):
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa and any(
|
||||
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
|
||||
):
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
@ -492,6 +492,7 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
|
||||
all_model_classes = (
|
||||
(InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel) if is_torch_available() else ()
|
||||
)
|
||||
additional_model_inputs = ["qformer_input_ids", "input_ids"]
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
@ -702,15 +703,11 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
|
||||
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
|
||||
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
|
||||
|
||||
# `None` is the requested implementation and will be assigned to each sub-config
|
||||
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
|
||||
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
|
||||
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
|
||||
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
|
||||
self.assertTrue(model.language_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model.qformer.config._attn_implementation == "eager")
|
||||
|
||||
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
@ -721,20 +718,13 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
if (
|
||||
class_name.endswith("Attention")
|
||||
and getattr(submodule, "config", None)
|
||||
and submodule.config._attn_implementation == "sdpa"
|
||||
):
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa and any(
|
||||
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
|
||||
):
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_video():
|
||||
|
@ -30,6 +30,7 @@ from transformers.testing_utils import (
|
||||
IS_ROCM_SYSTEM,
|
||||
IS_XPU_SYSTEM,
|
||||
require_torch,
|
||||
require_torch_sdpa,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
@ -42,6 +43,7 @@ from transformers.utils import (
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||
ModelTesterMixin,
|
||||
_config_zero_init,
|
||||
floats_tensor,
|
||||
@ -259,6 +261,7 @@ class Kosmos2ModelTester:
|
||||
@require_torch
|
||||
class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Kosmos2Model, Kosmos2ForConditionalGeneration) if is_torch_available() else ()
|
||||
additional_model_inputs = ["input_ids", "image_embeds_position_mask"]
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Kosmos2Model,
|
||||
@ -462,6 +465,14 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_generate_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
@unittest.skip("KOSMOS-2 doesn't support padding")
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_left_padding_compatibility(self):
|
||||
# Overwrite because Kosmos-2 need to padd pixel values and pad image-attn-mask
|
||||
|
@ -219,9 +219,10 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
else {}
|
||||
)
|
||||
is_encoder_decoder = False
|
||||
fx_compatible = True
|
||||
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
test_head_masking = False # new attn API doesn't support head mask
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
|
@ -109,6 +109,7 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
|
||||
all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
|
||||
additional_model_inputs = ["image_sizes"]
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_torchscript = False
|
||||
|
@ -3765,6 +3765,10 @@ class ModelTesterMixin:
|
||||
key = "decoder_hidden_states"
|
||||
elif "logits" in outputs_eager and "Classification" in model_class.__name__:
|
||||
key = "logits"
|
||||
elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower():
|
||||
outputs_eager = outputs_eager["language_model_outputs"]
|
||||
outputs_sdpa = outputs_sdpa["language_model_outputs"]
|
||||
key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states"
|
||||
else:
|
||||
key = "hidden_states"
|
||||
|
||||
@ -4002,14 +4006,14 @@ class ModelTesterMixin:
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
|
||||
sub_models_supporting_fa2 = [
|
||||
module._supports_flash_attn_2
|
||||
(module._supports_flash_attn_2 or module._supports_attention_backend)
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_fa2_all_modules = (
|
||||
all(sub_models_supporting_fa2)
|
||||
if len(sub_models_supporting_fa2) > 0
|
||||
else model._supports_flash_attn_2
|
||||
else (model._supports_flash_attn_2 or model._supports_attention_backend)
|
||||
)
|
||||
if not supports_fa2_all_modules:
|
||||
with self.assertRaises(ValueError):
|
||||
|