[VLMs] support attention backends (#37576)

* update models

* why rename

* return attn weights when sdpa

* fixes

* fix attn implementation composite

* fix moshi

* add message

* add typings

* use explicitly all flags for each attn type

* fix some tests

* import what is needed

* kosmos on main has new attention already, yay

* new models in main, run fixup

* won't fix kosmos yet

* fix-copies

* clean up after rebasing

* fix tests

* style

* don't cast attns to fp32

* did we update ruff? okay, let's just do what it asks

* fix pixtral after rebase
Authored by Raushan Turganbay on 2025-05-08 18:18:54 +02:00, committed by GitHub
parent e296c63cd4
commit d23aae2b8c
47 changed files with 1318 additions and 1555 deletions
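
Note: the practical effect of the `_supports_flash_attn_2` / `_supports_sdpa` / `_supports_flex_attn` / `_supports_attention_backend` flags added below is that these vision-language models can now be loaded with any of the supported attention backends. A minimal usage sketch (the checkpoint name is illustrative, and this assumes a transformers build that includes this change):

import torch
from transformers import Blip2ForConditionalGeneration

# The backend is chosen at load time; "eager", "sdpa" and "flash_attention_2"
# map onto the flags set in this diff, and "flex_attention" is available where
# _supports_flex_attn is enabled.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",     # illustrative checkpoint
    torch_dtype=torch.float16,
    attn_implementation="sdpa",      # or "flash_attention_2" / "eager"
)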

View File

@ -833,6 +833,7 @@ class PretrainedConfig(PushToHubMixin):
if "model_type" in value:
# Needs to be set even if it's not in the diff
diff["model_type"] = value["model_type"]
serializable_config_dict[key] = diff
elif (
key not in default_config_dict
@ -1003,6 +1004,8 @@ class PretrainedConfig(PushToHubMixin):
del d["_commit_hash"]
if "_attn_implementation_internal" in d:
del d["_attn_implementation_internal"]
if "_attn_implementation_autoset" in d:
del d["_attn_implementation_autoset"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in d:
del d["base_model_tp_plan"]

View File

@ -3430,9 +3430,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]
# Unset attn implementation so it can be set to another one when loading back
model_to_save.config._attn_implementation_autoset = False
# If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:

View File

@ -669,6 +669,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = False
_supports_sdpa = True
_supports_cache_class = True
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.initializer_range
@ -1409,6 +1410,7 @@ class AriaModel(AriaPreTrainedModel):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
def forward(
self,
@ -1424,6 +1426,7 @@ class AriaModel(AriaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, AriaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1470,16 +1473,16 @@ class AriaModel(AriaPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
output = AriaModelOutputWithPast(
return AriaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
def _create_patch_attention_mask(self, pixel_mask):
if pixel_mask is None:
@ -1563,7 +1566,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1645,6 +1648,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
@ -1655,7 +1659,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return AriaCausalLMOutputWithPast(

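The `**kwargs: Unpack[FlashAttentionKwargs]` / `**kwargs: Unpack[KwargsForCausalLM]` threading above is the typed-kwargs pattern this PR rolls out across the VLMs: attention and loss keyword arguments are declared once as a TypedDict-style class and forwarded unchanged from the outermost forward down to the attention and loss call sites. A standalone sketch of the idea, using hypothetical names (SketchCausalLMKwargs, sketch_forward) rather than the library's own classes:

from typing_extensions import TypedDict, Unpack

class SketchCausalLMKwargs(TypedDict, total=False):
    # Hypothetical stand-in for FlashAttentionKwargs + LossKwargs.
    max_length_q: int
    num_items_in_batch: int

def sketch_attention(**kwargs: Unpack[SketchCausalLMKwargs]) -> int:
    # Innermost consumer only reads the keys it understands.
    return kwargs.get("max_length_q", 0)

def sketch_forward(input_ids, **kwargs: Unpack[SketchCausalLMKwargs]) -> int:
    # Outermost forward passes the typed kwargs through untouched,
    # mirroring `**kwargs: Unpack[KwargsForCausalLM]` in the hunks above.
    return sketch_attention(**kwargs)

sketch_forward([1, 2, 3], max_length_q=128, num_items_in_batch=4)
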
View File

@ -32,6 +32,7 @@ from ...image_utils import (
valid_images,
validate_preprocess_arguments,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
@ -40,6 +41,7 @@ from ...tokenization_utils import (
TextInput,
)
from ...utils import (
LossKwargs,
TensorType,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@ -1240,6 +1242,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = False
_supports_sdpa = True
_supports_cache_class = True
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.initializer_range
@ -1290,6 +1293,9 @@ class AriaTextModel(LlamaModel):
self.post_init()
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
"""
Aria model for causal language modeling tasks.
@ -1434,6 +1440,7 @@ class AriaModel(LlavaModel):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
def forward(
self,
@ -1449,6 +1456,7 @@ class AriaModel(LlavaModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, AriaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1495,16 +1503,16 @@ class AriaModel(LlavaModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
output = AriaModelOutputWithPast(
return AriaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
@ -1533,7 +1541,7 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1615,6 +1623,7 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
@ -1625,7 +1634,7 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return AriaCausalLMOutputWithPast(

View File

@ -27,9 +27,12 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -124,6 +127,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = False
_supports_static_cache = False
_supports_attention_backend = True
def _init_weights(self, module):
std = (
@ -358,6 +362,7 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
@can_return_tuple
@add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING)
def forward(
self,
@ -375,7 +380,7 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, AyaVisionModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -434,17 +439,19 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = AyaVisionModelOutputWithPast(
return AyaVisionModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
@ -512,7 +519,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
**lm_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, AyaVisionCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -589,7 +596,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
return_dict=True,
cache_position=cache_position,
image_sizes=image_sizes,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -599,7 +606,9 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return AyaVisionCausalLMOutputWithPast(
loss=loss,

View File

@ -20,6 +20,7 @@ import torch
from torch import nn
from transformers.models.llava.modeling_llava import (
KwargsForCausalLM,
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
@ -27,6 +28,7 @@ from transformers.models.llava.modeling_llava import (
)
from ...activations import ACT2FN
from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings,
logging,
@ -148,7 +150,7 @@ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
**lm_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, AyaVisionCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -213,7 +215,7 @@ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
**kwargs,
)

View File

@ -16,7 +16,7 @@
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@ -25,15 +25,18 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
LossKwargs,
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@ -255,6 +258,30 @@ class Blip2VisionEmbeddings(nn.Module):
return embeddings
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> BLIP doesn't cast attn weights to fp32
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Blip2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -270,7 +297,8 @@ class Blip2Attention(nn.Module):
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = nn.Dropout(config.attention_dropout)
self.is_causal = False
self.attention_dropout = config.attention_dropout
# small tweak here compared to CLIP, no bias here
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
@ -296,6 +324,7 @@ class Blip2Attention(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@ -308,31 +337,32 @@ class Blip2Attention(nn.Module):
)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_interface: Callable = eager_attention_forward
attention_scores = attention_scores * self.scale
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=None,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
**kwargs,
)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.projection(context_layer)
outputs = (output, attention_probs) if output_attentions else (output, None)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.projection(attn_output)
outputs = (attn_output, attn_weights) if output_attentions else (attn_output, None)
return outputs
@ -410,6 +440,10 @@ class Blip2PreTrainedModel(PreTrainedModel):
config_class = Blip2Config
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_supports_attention_backend = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_no_split_modules = [
"Blip2Attention",
@ -1332,6 +1366,11 @@ class Blip2TextEmbeddings(nn.Module):
BLIP_2_QFORMER_START_DOCSTRING,
)
class Blip2QFormerModel(Blip2PreTrainedModel):
_supports_attention_backend = False # adds position on attn weights before last matmul
_supports_flash_attn_2 = False
_supports_sdpa = False
_supports_flex_attn = False
def __init__(self, config: Blip2QFormerConfig):
super().__init__(config)
self.config = config
@ -1511,6 +1550,9 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"""
BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
@ -1526,10 +1568,10 @@ class Blip2Model(Blip2PreTrainedModel):
def __init__(self, config: Blip2Config):
super().__init__(config)
self.vision_model = Blip2VisionModel(config.vision_config)
self.vision_model = Blip2VisionModel._from_config(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
self.qformer = Blip2QFormerModel(config.qformer_config)
self.qformer = Blip2QFormerModel._from_config(config.qformer_config)
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
if config.use_decoder_only_language_model:
@ -1580,6 +1622,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[KwargsForCausalLM],
):
r"""
Returns:
@ -1611,6 +1654,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
else:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
@ -1624,6 +1668,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
**kwargs,
)
return text_outputs
@ -1749,6 +1794,7 @@ class Blip2Model(Blip2PreTrainedModel):
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
r"""
Returns:
@ -1826,6 +1872,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
@ -1851,6 +1898,7 @@ class Blip2Model(Blip2PreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True, # toggle for easier access to loss/logits below
labels=labels,
**kwargs,
)
loss = outputs.loss
logits = outputs.logits
@ -1981,10 +2029,10 @@ class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
def __init__(self, config: Blip2Config):
super().__init__(config)
self.vision_model = Blip2VisionModel(config.vision_config)
self.vision_model = Blip2VisionModel._from_config(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
self.qformer = Blip2QFormerModel(config.qformer_config)
self.qformer = Blip2QFormerModel._from_config(config.qformer_config)
# vision projection layer
self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size)
@ -2102,10 +2150,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
def __init__(self, config: Blip2Config):
super().__init__(config)
self.vision_model = Blip2VisionModel(config.vision_config)
self.vision_model = Blip2VisionModel._from_config(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
self.qformer = Blip2QFormerModel(config.qformer_config)
self.qformer = Blip2QFormerModel._from_config(config.qformer_config)
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
if config.use_decoder_only_language_model:
@ -2180,6 +2228,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
r"""
Returns:
@ -2308,6 +2357,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
@ -2334,6 +2384,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
return_dict=True, # toggle for easier access to loss/logits below
labels=labels,
use_cache=use_cache,
**kwargs,
)
loss = outputs.loss
logits = outputs.logits
@ -2463,12 +2514,12 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
def __init__(self, config: Blip2Config):
super().__init__(config)
self.vision_model = Blip2VisionModel(config.vision_config)
self.vision_model = Blip2VisionModel._from_config(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
self.embeddings = Blip2TextEmbeddings(config.qformer_config)
self.qformer = Blip2QFormerModel(config.qformer_config)
self.qformer = Blip2QFormerModel._from_config(config.qformer_config)
# vision projection layer
self.vision_projection = nn.Linear(config.qformer_config.hidden_size, config.image_text_hidden_size)

View File

@ -14,31 +14,32 @@
# limitations under the License.
"""PyTorch Chameleon model."""
import math
from functools import cached_property
from typing import Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
is_torchdynamo_compiling,
logging,
@ -235,6 +236,33 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class ChameleonAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -259,6 +287,7 @@ class ChameleonAttention(nn.Module):
self.rope_theta = config.rope_theta
self.is_causal = True
self.model_parallel_size = config.model_parallel_size
self.scaling = self.head_dim**-0.5
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
@ -338,144 +367,26 @@ class ChameleonAttention(nn.Module):
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attention_interface: Callable = eager_attention_forward
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonFlashAttention2(ChameleonAttention):
"""
Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
query_states = self.q_norm(query_states)
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
key_states = self.k_norm(key_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (ChameleonRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@ -487,114 +398,13 @@ class ChameleonFlashAttention2(ChameleonAttention):
return attn_output, attn_weights, past_key_value
class ChameleonSdpaAttention(ChameleonAttention):
"""
Chameleon attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`ChameleonAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from ChameleonAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"ChameleonModel is using ChameleonSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
query_states = self.q_norm(query_states)
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
key_states = self.k_norm(key_states)
query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
CHAMELEON_ATTENTION_CLASSES = {
"eager": ChameleonAttention,
"flash_attention_2": ChameleonFlashAttention2,
"sdpa": ChameleonSdpaAttention,
}
# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
# TODO(joao): add me back asap :)
class ChameleonDecoderLayer(nn.Module):
def __init__(self, config: ChameleonConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
self.mlp = ChameleonMLP(config)
self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -669,7 +479,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
self.mlp = ChameleonMLP(config)
self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -1052,6 +862,7 @@ class ChameleonPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_static_cache = True
_supports_param_buffer_assignment = False
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.initializer_range
@ -1256,6 +1067,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1342,6 +1154,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]
@ -1498,6 +1311,9 @@ class ChameleonModel(ChameleonPreTrainedModel):
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"Chameleon Model with a head on top used for outputting logits for next token prediction.",
CHAMELEON_START_DOCSTRING,
@ -1532,6 +1348,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
def get_decoder(self):
return self.model
@can_return_tuple
@add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1548,6 +1365,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1596,6 +1414,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
@ -1607,22 +1426,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixi
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,

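The hunk above replaces Chameleon's hand-written causal LM loss with the shared `self.loss_function(...)` helper. For reference, a minimal sketch of the shift-by-one cross-entropy that the removed lines implemented (plain PyTorch, not the library helper itself):

import torch
import torch.nn.functional as F

def shifted_causal_lm_loss(logits, labels, vocab_size):
    # Tokens < n predict token n: drop the last logit and the first label,
    # then apply cross-entropy over the flattened sequence.
    shift_logits = logits[..., :-1, :].contiguous().float()
    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
    return F.cross_entropy(shift_logits.view(-1, vocab_size), shift_labels.view(-1))

# Toy shapes: batch 2, sequence length 5, vocabulary 11 (all hypothetical).
loss = shifted_causal_lm_loss(torch.randn(2, 5, 11), torch.randint(0, 11, (2, 5)), 11)
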
View File

@ -1010,6 +1010,7 @@ class Emu3VQVAE(PreTrainedModel):
_supports_sdpa = True
_supports_flash_attn_2 = True
_supports_flex_attn = True
_supports_attention_backend = True
_no_split_modules = [
"Emu3VQVAETemporalResnetBlock",
"Emu3VQVAEAttentionBlock",
@ -1202,6 +1203,7 @@ class Emu3PreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_static_cache = True
_supports_param_buffer_assignment = False
_supports_attention_backend = True
_supports_flex_attn = True
def _init_weights(self, module):
@ -1836,6 +1838,7 @@ class Emu3Model(Emu3PreTrainedModel):
image = self.vqmodel.decode(image_tokens)
return image
@can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
def forward(
self,
@ -1851,6 +1854,7 @@ class Emu3Model(Emu3PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1884,8 +1888,9 @@ class Emu3Model(Emu3PreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
return outputs
@ -1941,6 +1946,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -2007,8 +2013,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
@ -2018,7 +2025,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return CausalLMOutputWithPast(
loss=loss,

View File

@ -25,10 +25,12 @@ import torch.utils.checkpoint
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
CausalLMOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@ -41,6 +43,7 @@ from ..chameleon.modeling_chameleon import (
ChameleonVQVAEEncoderConvDownsample,
)
from ..llama.modeling_llama import (
KwargsForCausalLM,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
@ -736,6 +739,7 @@ class Emu3VQVAE(PreTrainedModel):
_supports_sdpa = True
_supports_flash_attn_2 = True
_supports_flex_attn = True
_supports_attention_backend = True
_no_split_modules = [
"Emu3VQVAETemporalResnetBlock",
"Emu3VQVAEAttentionBlock",
@ -898,6 +902,7 @@ class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
"Emu3DecoderLayer",
]
_supports_flex_attn = True
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.get_text_config().initializer_range
@ -1179,6 +1184,7 @@ class Emu3Model(Emu3PreTrainedModel):
image = self.vqmodel.decode(image_tokens)
return image
@can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
def forward(
self,
@ -1194,6 +1200,7 @@ class Emu3Model(Emu3PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1227,8 +1234,9 @@ class Emu3Model(Emu3PreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
return outputs
@ -1284,6 +1292,7 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1350,8 +1359,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
@ -1361,7 +1371,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return CausalLMOutputWithPast(
loss=loss,

View File

@ -21,10 +21,19 @@ import torch.utils.checkpoint
from torch import nn
from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...models.auto.modeling_auto import AutoModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
from .configuration_fuyu import FuyuConfig
@ -58,6 +67,10 @@ class FuyuPreTrainedModel(PreTrainedModel):
config_class = FuyuConfig
base_model_prefix = "fuyu"
supports_gradient_checkpointing = True
_supports_attention_backend = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_no_split_modules = []
_skip_keys_device_placement = "past_key_values"
@ -142,6 +155,9 @@ FUYU_INPUTS_DOCSTRING = r"""
"""
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"""The Fuyu model which consists of a vision backbone and a language model, without a language modeling head.""",
FUYU_START_DOCSTRING,
@ -224,8 +240,8 @@ class FuyuModel(FuyuPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -323,6 +339,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.model.get_decoder()
@can_return_tuple
@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -392,7 +409,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
return_dict=return_dict,
return_dict=True,
# don't pass kwargs because Persimmon-backbone doesn't accept FA2 kwargs yet, TODO: raushan
)
@ -407,10 +424,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,

View File

@ -30,9 +30,12 @@ import torch.nn.functional as F
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -619,6 +622,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -745,6 +749,7 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
@can_return_tuple
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
def forward(
self,
@ -759,6 +764,7 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, GotOcr2ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -800,16 +806,19 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
output = GotOcr2ModelOutputWithPast(
return GotOcr2ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
@ -874,6 +883,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -937,6 +947,8 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
hidden_states = outputs[0]
@ -946,7 +958,9 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return GotOcr2CausalLMOutputWithPast(
loss=loss,

View File

@ -21,6 +21,7 @@ import torch.nn as nn
import torch.utils.checkpoint
from transformers.models.llava.modeling_llava import (
KwargsForCausalLM,
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
@ -30,6 +31,8 @@ from transformers.models.llava.modeling_llava import (
from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -393,6 +396,7 @@ class GotOcr2Model(LlavaModel):
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
@can_return_tuple
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
def forward(
self,
@ -407,6 +411,7 @@ class GotOcr2Model(LlavaModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, GotOcr2ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -448,16 +453,16 @@ class GotOcr2Model(LlavaModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
output = GotOcr2ModelOutputWithPast(
return GotOcr2ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
@ -479,6 +484,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -542,6 +548,8 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
hidden_states = outputs[0]
@ -551,7 +559,9 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return GotOcr2CausalLMOutputWithPast(
loss=loss,

View File

@ -20,24 +20,27 @@
"""PyTorch Idefics model."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig, PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PretrainedConfig, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
@ -500,6 +503,30 @@ class IdeficsMLP(nn.Module):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# this was adapted from LlamaAttention
class IdeficsAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -515,11 +542,13 @@ class IdeficsAttention(nn.Module):
layer_idx: Optional[int] = None,
):
super().__init__()
self.config = config
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.dropout = dropout
self.is_causal = True
self.scaling = self.head_dim**-0.5
self.layer_idx = layer_idx
if layer_idx is None:
@ -596,6 +625,7 @@ class IdeficsAttention(nn.Module):
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# if key_value_states are provided this layer is used as a cross-attention layer
is_cross_attention = self.is_cross_attention or key_value_states is not None
@ -631,47 +661,33 @@ class IdeficsAttention(nn.Module):
query_states = self.q_layer_norm(query_states)
key_states = self.k_layer_norm(key_states)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
attention_interface: Callable = eager_attention_forward
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scaling,
**kwargs,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
attn_weights = None
if output_attentions:
logger.warning_once(
"attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
)
attn_weights = None
return attn_output, attn_weights, past_key_value
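The hunk above is the core pattern this PR rolls out across models: keep an eager reference path and look every other backend up in a registry. Below is a minimal, self-contained sketch of that dispatch; ALL_ATTENTION_FUNCTIONS and the eager helper's shape come from the library, while ToyAttention and the SimpleNamespace config stub are purely illustrative.

# Minimal sketch of the backend dispatch used above; ToyAttention and the config stub
# are illustrative, ALL_ATTENTION_FUNCTIONS is the library's backend registry.
from types import SimpleNamespace
from typing import Callable, Optional

import torch
from torch import nn

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Reference path: explicit matmul + softmax, mirroring the helper added in this diff.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    return torch.matmul(attn_weights, value).transpose(1, 2).contiguous(), attn_weights


class ToyAttention(nn.Module):
    def __init__(self, config, head_dim=16):
        super().__init__()
        self.config = config
        self.scaling = head_dim**-0.5
        self.is_causal = False  # backends such as SDPA read this flag off the module

    def forward(self, query, key, value, attention_mask: Optional[torch.Tensor] = None, **kwargs):
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            # Every registered backend shares this signature, so swapping is a dict lookup.
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        return attention_interface(
            self, query, key, value, attention_mask, dropout=0.0, scaling=self.scaling, **kwargs
        )


q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
for impl in ("eager", "sdpa"):
    out, _ = ToyAttention(SimpleNamespace(_attn_implementation=impl))(q, k, v)
    print(impl, tuple(out.shape))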
@ -706,6 +722,7 @@ class IdeficsDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -734,6 +751,7 @@ class IdeficsDecoderLayer(nn.Module):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@ -833,6 +851,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -875,6 +894,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
key_value_states=image_hidden_states,
attention_mask=image_attention_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
# Fill in zeros for cross_attention hidden_states of tokens attending to no images
@ -927,7 +947,9 @@ class IdeficsPreTrainedModel(PreTrainedModel):
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
_supports_sdpa = True
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_static_cache = False # IDEFICS cannot compile due to dynamic control flow when checking inputs
_supports_attention_backend = True
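With the support flags above in place, the backend becomes a load-time choice instead of a code path hard-wired into the model. A usage sketch follows; the checkpoint name and dtype are assumptions, not part of this diff, and "flash_attention_2" additionally requires the flash-attn package.

# Usage sketch: pick the attention backend when loading IDEFICS.
import torch
from transformers import AutoProcessor, IdeficsForVisionText2Text

checkpoint = "HuggingFaceM4/idefics-9b"  # assumed checkpoint; any IDEFICS checkpoint behaves the same way
model = IdeficsForVisionText2Text.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
processor = AutoProcessor.from_pretrained(checkpoint)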
def _init_weights(self, module):
# important: this ported version of Idefics isn't meant for training from scratch - only
@ -1029,6 +1051,9 @@ LLAMA_INPUTS_DOCSTRING = r"""
"""
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
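The one-liner above merges the flash-attention kwargs with the loss kwargs so a single **kwargs: Unpack[KwargsForCausalLM] annotation can carry both through the whole stack. A generic sketch of the pattern with simplified stand-in fields (the real classes are FlashAttentionKwargs and LossKwargs):

# Generic sketch of the kwargs-merging pattern; the field names are simplified stand-ins,
# not the exact keys of FlashAttentionKwargs / LossKwargs.
from typing import Optional, TypedDict

from typing_extensions import Unpack  # `typing.Unpack` on Python >= 3.11


class AttentionKwargs(TypedDict, total=False):
    max_length_q: Optional[int]
    max_length_k: Optional[int]


class SimpleLossKwargs(TypedDict, total=False):
    num_items_in_batch: Optional[int]


class KwargsForCausalLM(AttentionKwargs, SimpleLossKwargs): ...


def forward(**kwargs: Unpack[KwargsForCausalLM]) -> None:
    # One typed bag travels from the head through the decoder layers down to the
    # attention backend and the loss, without re-declaring each key along the way.
    print(sorted(kwargs))


forward(num_items_in_batch=8, max_length_q=128)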
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
@ -1112,6 +1137,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
@ -1130,6 +1156,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
device = input_ids.device if input_ids is not None else inputs_embeds.device
@ -1292,6 +1319,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
past_key_value=None, # not implemented
**kwargs,
)
hidden_states = outputs[0]
@ -1303,6 +1331,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
return layer_outputs
@ -1348,6 +1377,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
cross_layer_interval=self.cross_layer_interval,
gated_cross_attn_layers=self.gated_cross_attn_layers,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]
@ -1368,12 +1398,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size)
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
if v is not None
)
return IdeficsBaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@ -1565,6 +1590,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
):
output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=IdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1585,6 +1611,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1641,8 +1668,9 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
@ -1650,24 +1678,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return IdeficsCausalLMOutputWithPast(
loss=loss,
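The manual shift-and-flatten cross entropy deleted in this hunk now lives behind self.loss_function, with num_items_in_batch arriving through the new **kwargs. As a reference, an approximate sketch of what that shared helper computes; the library version handles more edge cases, so treat this as an illustration rather than its exact code.

# Approximate sketch of the causal-LM loss the models now delegate to; it reproduces the
# shift/flatten logic removed above but is not the library's exact helper.
from typing import Optional

import torch
import torch.nn.functional as F


def causal_lm_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
) -> torch.Tensor:
    logits = logits.float()                          # upcast once, centrally
    shift_logits = logits[..., :-1, :].contiguous()  # tokens < n predict n
    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1),
        ignore_index=ignore_index,
        reduction=reduction,
    )
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch             # normalize by the true token count
    return loss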
View File
@ -20,17 +20,20 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
@ -514,6 +517,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_attention_backend = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -1089,6 +1093,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device)
return new_inputs_embeds
@can_return_tuple
@add_start_docstrings_to_model_forward(
"""
Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
@ -1117,6 +1122,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, Idefics2BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1226,15 +1232,13 @@ class Idefics2Model(Idefics2PreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
if return_legacy_cache and use_cache:
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
if not return_dict:
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
return Idefics2BaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
@ -1244,6 +1248,9 @@ class Idefics2Model(Idefics2PreTrainedModel):
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"""The Idefics2 Model with a language modeling head. It is made up of a SigLIP vision encoder, with a language modeling head on top. """,
IDEFICS2_START_DOCSTRING,
@ -1292,6 +1299,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@can_return_tuple
@add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Idefics2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1311,6 +1319,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1386,7 +1395,8 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
hidden_states = outputs[0]
@ -1396,26 +1406,9 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return Idefics2CausalLMOutputWithPast(
loss=loss,
View File
@ -20,17 +20,20 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
@ -532,6 +535,7 @@ class Idefics3PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_attention_backend = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -816,6 +820,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
return new_inputs_embeds
@can_return_tuple
@add_start_docstrings_to_model_forward(
"""
Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
@ -843,6 +848,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -939,12 +945,10 @@ class Idefics3Model(Idefics3PreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
if not return_dict:
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
return Idefics3BaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
@ -954,6 +958,9 @@ class Idefics3Model(Idefics3PreTrainedModel):
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"""The Idefics3 Model with a language modeling head. It is made up of a SigLIP vision encoder, with a language modeling head on top. """,
IDEFICS3_START_DOCSTRING,
@ -1009,6 +1016,7 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@can_return_tuple
@add_start_docstrings_to_model_forward(IDEFICS3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Idefics3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1028,6 +1036,7 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, Idefics3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1117,7 +1126,8 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
hidden_states = outputs[0]
@ -1127,26 +1137,9 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return Idefics3CausalLMOutputWithPast(
loss=loss,
View File
@ -16,27 +16,30 @@
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
LossKwargs,
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
torch_int,
@ -159,6 +162,30 @@ class InstructBlipVisionEmbeddings(nn.Module):
return embeddings
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBLIP doesn't cast attn weights to fp32
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip
class InstructBlipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -175,7 +202,8 @@ class InstructBlipAttention(nn.Module):
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = nn.Dropout(config.attention_dropout)
self.is_causal = False
self.attention_dropout = config.attention_dropout
# small tweak here compared to CLIP, no bias here
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
@ -201,6 +229,7 @@ class InstructBlipAttention(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@ -213,31 +242,32 @@ class InstructBlipAttention(nn.Module):
)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_interface: Callable = eager_attention_forward
attention_scores = attention_scores * self.scale
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=None,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
**kwargs,
)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.projection(context_layer)
outputs = (output, attention_probs) if output_attentions else (output, None)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.projection(attn_output)
outputs = (attn_output, attn_weights) if output_attentions else (attn_output, None)
return outputs
@ -315,6 +345,10 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
config_class = InstructBlipConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_supports_attention_backend = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)
@ -1087,6 +1121,11 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
instruction as input.
"""
_supports_attention_backend = False # adds position on attn weights before last matmul
_supports_flash_attn_2 = False
_supports_sdpa = False
_supports_flex_attn = False
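The comment above is the reason the Q-Former opts out: its attention injects a position-dependent term into the weight matrix, a step the fused backends never expose because they never materialize those weights. A generic illustration of such a layer follows (not the Q-Former's actual code; position_scores stands in for its relative-position term).

# Generic illustration of attention whose weight matrix is modified before the final
# matmul; fused SDPA/flash kernels have no hook for this, hence the opt-out flags above.
import torch
from torch import nn


def attention_with_injected_scores(query, key, value, position_scores):
    # query/key/value: (batch, heads, seq, head_dim); position_scores broadcasts over the weights
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * query.shape[-1] ** -0.5
    attn_weights = attn_weights + position_scores  # the extra step applied to the materialized weights
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value)


q = k = v = torch.randn(1, 2, 6, 8)
print(attention_with_injected_scores(q, k, v, torch.randn(1, 1, 6, 6)).shape)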
def __init__(self, config: InstructBlipQFormerConfig):
super().__init__(config)
self.config = config
@ -1277,6 +1316,9 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"""
InstructBLIP base Model consisting of language model, qformer and vision encoder.
@ -1337,6 +1379,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
@can_return_tuple
@add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING)
def forward(
self,
@ -1352,6 +1395,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@ -1404,6 +1448,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
else:
outputs = self.language_model(
@ -1415,11 +1460,9 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
if not return_dict:
return (vision_outputs, query_outputs, outputs)
return InstructBlipForConditionalGenerationModelOutput(
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
@ -1448,10 +1491,10 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
def __init__(self, config: InstructBlipConfig):
super().__init__(config)
self.vision_model = InstructBlipVisionModel(config.vision_config)
self.vision_model = InstructBlipVisionModel._from_config(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
self.qformer = InstructBlipQFormerModel(config.qformer_config)
self.qformer = InstructBlipQFormerModel._from_config(config.qformer_config)
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
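A note on the constructor change above: building the vision tower and Q-Former through _from_config gives the composite model a hook to hand each submodule a sub-config whose settings, notably the attention implementation, have already been decided, instead of each submodule re-deriving its own defaults. A hedged sketch of that effect with a tiny, randomly initialized vision config (the small hyperparameters are assumptions chosen only to keep it cheap):

# Hedged sketch, not the library's internals: a sub-config whose _attn_implementation has
# been set by the wrapper is respected when the submodule is built via _from_config.
from transformers import InstructBlipVisionConfig
from transformers.models.instructblip.modeling_instructblip import InstructBlipVisionModel

vision_config = InstructBlipVisionConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4,
    image_size=32, patch_size=8,
)
vision_config._attn_implementation = "sdpa"      # what the composite wrapper decides for its children
vision_model = InstructBlipVisionModel._from_config(vision_config)
print(vision_model.config._attn_implementation)  # expected: "sdpa"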
@ -1516,6 +1559,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
@can_return_tuple
@add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=InstructBlipForConditionalGenerationModelOutput, config_class=InstructBlipVisionConfig
@ -1535,6 +1579,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1646,21 +1691,15 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous().to(logits.device)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction="mean")
loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
else:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
@ -1672,14 +1711,11 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
**kwargs,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
return ((loss,) + output) if loss is not None else output
return InstructBlipForConditionalGenerationModelOutput(
loss=loss,
logits=logits,
View File
@ -21,26 +21,29 @@
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
LossKwargs,
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
torch_int,
@ -130,6 +133,30 @@ class InstructBlipVideoVisionEmbeddings(nn.Module):
return embeddings
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBlipVideo doesn't cast attn weights to fp32
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class InstructBlipVideoAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -145,7 +172,8 @@ class InstructBlipVideoAttention(nn.Module):
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = nn.Dropout(config.attention_dropout)
self.is_causal = False
self.attention_dropout = config.attention_dropout
# small tweak here compared to CLIP, no bias here
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
@ -171,6 +199,7 @@ class InstructBlipVideoAttention(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@ -183,31 +212,32 @@ class InstructBlipVideoAttention(nn.Module):
)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_interface: Callable = eager_attention_forward
attention_scores = attention_scores * self.scale
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=None,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
**kwargs,
)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.projection(context_layer)
outputs = (output, attention_probs) if output_attentions else (output, None)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.projection(attn_output)
outputs = (attn_output, attn_weights) if output_attentions else (attn_output, None)
return outputs
@ -852,6 +882,9 @@ class InstructBlipVideoQFormerEmbeddings(nn.Module):
return embeddings
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -945,6 +978,10 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
config_class = InstructBlipVideoConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_supports_attention_backend = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)
@ -1049,6 +1086,11 @@ class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
instruction as input.
"""
_supports_attention_backend = False # adds position on attn weights before last matmul
_supports_flash_attn_2 = False
_supports_sdpa = False
_supports_flex_attn = False
def __init__(self, config: InstructBlipVideoQFormerConfig):
super().__init__(config)
self.config = config
@ -1332,6 +1374,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
@can_return_tuple
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
def forward(
self,
@ -1347,6 +1390,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@ -1409,6 +1453,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
else:
outputs = self.language_model(
@ -1420,11 +1465,9 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
if not return_dict:
return (vision_outputs, query_outputs, outputs)
return InstructBlipVideoForConditionalGenerationModelOutput(
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
@ -1453,10 +1496,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
def __init__(self, config: InstructBlipVideoConfig):
super().__init__(config)
self.vision_model = InstructBlipVideoVisionModel(config.vision_config)
self.vision_model = InstructBlipVideoVisionModel._from_config(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
self.qformer = InstructBlipVideoQFormerModel(config.qformer_config)
self.qformer = InstructBlipVideoQFormerModel._from_config(config.qformer_config)
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
@ -1519,9 +1562,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
@can_return_tuple
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=InstructBlipVideoForConditionalGenerationModelOutput, config_class=InstructBlipVideoVisionConfig
output_type=InstructBlipVideoForConditionalGenerationModelOutput, config_class=InstructBlipVideoConfig
)
def forward(
self,
@ -1538,6 +1582,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1682,21 +1727,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous().to(logits.device)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction="mean")
loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
else:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
@ -1708,14 +1747,11 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
**kwargs,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
return ((loss,) + output) if loss is not None else output
return InstructBlipVideoForConditionalGenerationModelOutput(
loss=loss,
logits=logits,
View File
@ -18,7 +18,6 @@ from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.models.instructblip.configuration_instructblip import (
InstructBlipQFormerConfig,
@ -31,11 +30,14 @@ from transformers.models.instructblip.modeling_instructblip import (
InstructBlipPreTrainedModel,
InstructBlipQFormerModel,
InstructBlipVisionModel,
KwargsForCausalLM,
)
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from ...utils import add_start_docstrings_to_model_forward, logging
from ...processing_utils import Unpack
from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging, replace_return_docstrings
from ..auto import CONFIG_MAPPING, AutoConfig
@ -196,6 +198,7 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = None
class InstructBlipVideoModel(InstructBlipModel):
@can_return_tuple
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
def forward(
self,
@ -211,6 +214,7 @@ class InstructBlipVideoModel(InstructBlipModel):
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@ -273,6 +277,7 @@ class InstructBlipVideoModel(InstructBlipModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
else:
outputs = self.language_model(
@ -284,11 +289,9 @@ class InstructBlipVideoModel(InstructBlipModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
if not return_dict:
return (vision_outputs, query_outputs, outputs)
return InstructBlipVideoForConditionalGenerationModelOutput(
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
@ -297,6 +300,11 @@ class InstructBlipVideoModel(InstructBlipModel):
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=InstructBlipVideoForConditionalGenerationModelOutput, config_class=InstructBlipVideoConfig
)
def forward(
self,
pixel_values: torch.FloatTensor,
@ -312,6 +320,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
r"""
```python
@ -447,21 +456,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
**kwargs,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous().to(logits.device)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction="mean")
loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
else:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
@ -473,14 +476,11 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
**kwargs,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
return ((loss,) + output) if loss is not None else output
return InstructBlipVideoForConditionalGenerationModelOutput(
loss=loss,
logits=logits,
View File
@ -35,6 +35,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseMo
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
@ -621,6 +622,7 @@ class InternVLPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -828,6 +830,7 @@ class InternVLModel(InternVLPreTrainedModel):
return vision_features
@can_return_tuple
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
def forward(
self,
@ -845,7 +848,7 @@ class InternVLModel(InternVLPreTrainedModel):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, InternVLModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -904,17 +907,16 @@ class InternVLModel(InternVLPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = InternVLModelOutputWithPast(
return InternVLModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
"""Perform pixel shuffle downsampling on vision features.
@ -992,6 +994,9 @@ class InternVLCausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"""The INTERNVL model which consists of a vision backbone and a language model.""",
INTERNVL_START_DOCSTRING,
@ -1056,8 +1061,8 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor = None,
**lm_kwargs,
image_sizes: Optional[torch.Tensor] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
r"""
Args:
@ -1138,7 +1143,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return_dict=True,
cache_position=cache_position,
image_sizes=image_sizes,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -1148,7 +1153,9 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return InternVLCausalLMOutputWithPast(
loss=loss,
View File
@ -21,10 +21,10 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@ -32,10 +32,13 @@ from ...modeling_outputs import (
CausalLMOutputWithCrossAttentions,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
torch_int,
@ -466,7 +469,7 @@ class Kosmos2VisionEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> Kosmos2 doesn't cast attn weights to fp32
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
@ -481,7 +484,7 @@ def eager_attention_forward(
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
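For context on the softmax change just above (in line with the commit note about not casting attention weights to fp32): the eager helper now runs the softmax in the tensor's own dtype instead of upcasting to float32 and casting back. A minimal illustration of the two variants on toy bfloat16 scores; which one a model keeps is a precision/throughput trade-off, and this diff keeps Kosmos-2 on the native-dtype path.

# Minimal illustration of the two softmax variants touched above; the tensors are toy inputs.
import torch
from torch import nn

scores = torch.randn(1, 4, 8, 8, dtype=torch.bfloat16)

# Variant kept by this diff: softmax in the tensor's own dtype.
weights_native = nn.functional.softmax(scores, dim=-1)

# Variant removed here (still used by the siglip-copied helper): upcast, softmax, cast back.
weights_fp32 = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(scores.dtype)

assert weights_native.dtype == weights_fp32.dtype == torch.bfloat16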
@ -892,6 +895,7 @@ class KosmosTextAttention(nn.Module):
bias: bool = True,
):
super().__init__()
self.config = config
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
@ -929,6 +933,7 @@ class KosmosTextAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@ -953,8 +958,7 @@ class KosmosTextAttention(nn.Module):
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
query_states = self._shape(self.q_proj(hidden_states) * self.scaling)
attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2))
query_states = self._shape(self.q_proj(hidden_states))
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@ -966,32 +970,33 @@ class KosmosTextAttention(nn.Module):
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
src_len = key_states.size(2)
attention_interface: Callable = eager_attention_forward
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, seq_length, src_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, seq_length, src_len)}, but is {attention_mask.size()}"
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_weights = attn_weights + attention_mask
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
# attn_output = torch.bmm(attn_probs, value_states) ?
context_states = torch.matmul(attn_weights, value_states)
# attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) ?
context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
if self.inner_attn_ln is not None:
context_states = self.inner_attn_ln(context_states)
attn_output = self.inner_attn_ln(attn_output)
attn_output = self.out_proj(context_states)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
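One detail of the Kosmos-2 rework above: the query is no longer pre-multiplied by self.scaling; the scale is instead passed to the attention interface, since backends such as SDPA take it as an argument and would otherwise apply their own default on top of a pre-scaled query. A tiny equivalence check, illustrative only:

# Illustrative check: passing the scale to the kernel equals pre-scaling the queries,
# so the refactor applies it exactly once, as an argument.
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
scaling = q.shape[-1] ** -0.5

out_arg = F.scaled_dot_product_attention(q, k, v, scale=scaling)        # scale handled by the kernel
out_pre = F.scaled_dot_product_attention(q * scaling, k, v, scale=1.0)  # pre-scaled query, unit scale

torch.testing.assert_close(out_arg, out_pre, rtol=1e-5, atol=1e-5)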
@ -1060,6 +1065,7 @@ class Kosmos2TextBlock(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
@ -1076,6 +1082,7 @@ class Kosmos2TextBlock(nn.Module):
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@ -1103,6 +1110,7 @@ class Kosmos2TextBlock(nn.Module):
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@ -1216,6 +1224,7 @@ class Kosmos2TextTransformer(nn.Module):
return hidden_states
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
@ -1233,6 +1242,7 @@ class Kosmos2TextTransformer(nn.Module):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1338,6 +1348,7 @@ class Kosmos2TextTransformer(nn.Module):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
hidden_states = layer_outputs[0]
@ -1357,18 +1368,6 @@ class Kosmos2TextTransformer(nn.Module):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
present_key_value_states,
all_hidden_states,
all_self_attns,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
@ -1387,6 +1386,9 @@ class Kosmos2PreTrainedModel(PreTrainedModel):
config_class = Kosmos2Config
supports_gradient_checkpointing = True
_no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"]
_supports_attention_backend = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -1525,6 +1527,7 @@ class Kosmos2TextModel(Kosmos2PreTrainedModel):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=Kosmos2TextConfig)
def forward(
@ -1544,6 +1547,7 @@ class Kosmos2TextModel(Kosmos2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r"""
Returns:
@ -1565,9 +1569,13 @@ class Kosmos2TextModel(Kosmos2PreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"""
The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
@ -1600,6 +1608,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@can_return_tuple
@add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=Kosmos2TextConfig)
def forward(
@ -1620,6 +1629,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1652,27 +1662,14 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
lm_logits = self.lm_head(outputs[0])
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=lm_logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutputWithCrossAttentions(
loss=loss,
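
The hand-rolled shift-and-flatten `CrossEntropyLoss` deleted above is now delegated to `self.loss_function`, which also receives the forwarded `**kwargs`. Under the assumption that it resolves to the standard causal-LM loss, it does roughly what the removed block did; a sketch:

import torch
import torch.nn.functional as F

def causal_lm_loss_sketch(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int, **kwargs):
    labels = labels.to(logits.device)
    # shift so that tokens < n predict n, exactly like the removed code
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1),
        ignore_index=-100,   # padding/prompt tokens are conventionally masked with -100
    )
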
@ -1804,6 +1801,7 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
def set_input_embeddings(self, value):
self.text_model.model.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Kosmos2ModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1822,6 +1820,7 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, Kosmos2ModelOutput]:
r"""
Returns:
@ -1893,13 +1892,10 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
if not return_dict:
outputs = outputs + (image_embeds, projection_attentions, vision_model_output)
return tuple(output for output in outputs if output is not None)
return Kosmos2ModelOutput(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
@ -1946,6 +1942,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
def set_output_embeddings(self, new_embeddings):
self.text_model.set_output_embeddings(new_embeddings)
@can_return_tuple
@add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Kosmos2ForConditionalGenerationModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1964,6 +1961,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, Kosmos2ForConditionalGenerationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -2048,13 +2046,10 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
if not return_dict:
outputs = lm_outputs + (image_embeds, projection_attentions, vision_model_output)
return tuple(output for output in outputs if output is not None)
return Kosmos2ForConditionalGenerationModelOutput(
loss=lm_outputs.loss,
logits=lm_outputs.logits,

View File


@ -39,8 +39,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
@ -244,6 +246,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
@ -256,12 +259,13 @@ def eager_attention_forward(
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
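
Two things change in this eager path: the score scaling is precomputed (`scaling = head_dim ** -0.5`, mathematically identical to dividing by `sqrt(head_dim)`), and, per the commit note, the softmax is no longer upcast to fp32 for Llama 4. A quick numerical check of the scaling equivalence with toy shapes:

import math
import torch

head_dim = 128
scaling = head_dim ** -0.5
query = torch.randn(2, 8, 16, head_dim)
key_states = torch.randn(2, 8, 16, head_dim)

old = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(head_dim)
new = torch.matmul(query, key_states.transpose(2, 3)) * scaling
assert torch.allclose(old, new, atol=1e-6)
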
@ -607,6 +611,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
def forward(
self,
@ -712,13 +717,12 @@ class Llama4TextModel(Llama4PreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast(
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
@torch.compiler.disable(recursive=False) # the operations in this method are not compilable
def _update_causal_mask(
@ -931,6 +935,9 @@ class Llama4TextModel(Llama4PreTrainedModel):
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
_no_split_modules = ["Llama4TextDecoderLayer"]
base_model_prefix = "language_model"
@ -965,6 +972,7 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.model
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -981,7 +989,7 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1031,7 +1039,7 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
@ -1044,10 +1052,6 @@ class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
@ -1208,6 +1212,7 @@ class Llama4VisionAttention(nn.Module):
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_key_value_groups = 1
self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**-0.5
self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True)
@ -1593,7 +1598,6 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
_tp_plan = {}
base_model_prefix = ""
config_class = Llama4Config
_supports_flex_attn = True
def __init__(self, config: Llama4Config):
super().__init__(config)
@ -1673,7 +1677,7 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor = None,
**lm_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1780,7 +1784,7 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
**kwargs,
)
logits = outputs[0]


@ -23,9 +23,12 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -171,6 +174,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
# important: this ported version of Llava isn't meant for training from scratch - only
@ -331,6 +335,7 @@ class LlavaModel(LlavaPreTrainedModel):
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
def forward(
self,
@ -348,7 +353,7 @@ class LlavaModel(LlavaPreTrainedModel):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, LlavaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -407,17 +412,19 @@ class LlavaModel(LlavaPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = LlavaModelOutputWithPast(
return LlavaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
@ -484,8 +491,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor = None,
**lm_kwargs,
image_sizes: Optional[torch.Tensor] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -553,7 +560,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
return_dict=True,
cache_position=cache_position,
image_sizes=image_sizes,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -563,7 +570,9 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return LlavaCausalLMOutputWithPast(
loss=loss,


@ -26,9 +26,12 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -280,6 +283,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -528,6 +532,7 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
def forward(
self,
@ -545,7 +550,7 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, LlavaNextModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -609,17 +614,19 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = LlavaNextModelOutputWithPast(
return LlavaNextModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
@ -688,7 +695,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -756,7 +763,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -766,7 +773,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return LlavaNextCausalLMOutputWithPast(
loss=loss,


@ -30,9 +30,12 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -223,6 +226,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -581,6 +585,7 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
def forward(
self,
@ -599,7 +604,7 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -684,10 +689,10 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = LlavaNextVideoModelOutputWithPast(
return LlavaNextVideoModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
@ -695,7 +700,6 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
def get_video_features(
self,
@ -744,6 +748,9 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
return video_features
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
@ -811,7 +818,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
r"""
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, image_size, image_size)):
@ -915,10 +922,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
image_sizes=image_sizes,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -928,7 +935,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return LlavaNextVideoCausalLMOutputWithPast(
loss=loss,


@ -22,6 +22,7 @@ import torch.utils.checkpoint
from torch import nn
from transformers.models.llava_next.modeling_llava_next import (
KwargsForCausalLM,
LlavaNextCausalLMOutputWithPast,
LlavaNextForConditionalGeneration,
LlavaNextModel,
@ -31,6 +32,8 @@ from transformers.models.llava_next.modeling_llava_next import (
)
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -378,7 +381,7 @@ class LlavaNextVideoModel(LlavaNextModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -463,10 +466,10 @@ class LlavaNextVideoModel(LlavaNextModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = LlavaNextVideoModelOutputWithPast(
return LlavaNextVideoModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
@ -474,7 +477,6 @@ class LlavaNextVideoModel(LlavaNextModel):
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
@ -580,7 +582,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
r"""
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, image_size, image_size)):
@ -684,10 +686,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
image_sizes=image_sizes,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -697,7 +699,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return LlavaNextVideoCausalLMOutputWithPast(
loss=loss,


@ -30,9 +30,12 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
can_return_tuple,
is_torchdynamo_compiling,
@ -405,6 +408,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -570,6 +574,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
@can_return_tuple
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
def forward(
self,
@ -590,7 +595,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -681,10 +686,10 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = LlavaOnevisionModelOutputWithPast(
return LlavaOnevisionModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
@ -693,8 +698,6 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
def get_video_features(
self,
pixel_values: torch.FloatTensor,
@ -756,6 +759,9 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
return image_features
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_ONEVISION_START_DOCSTRING,
@ -824,7 +830,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -909,7 +915,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -919,7 +925,9 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return LlavaOnevisionCausalLMOutputWithPast(
loss=loss,


@ -21,6 +21,7 @@ from torch import nn
from transformers.models.llava_next.image_processing_llava_next_fast import LlavaNextImageProcessorFast
from transformers.models.llava_next_video.modeling_llava_next_video import (
KwargsForCausalLM,
LlavaNextVideoCausalLMOutputWithPast,
LlavaNextVideoForConditionalGeneration,
LlavaNextVideoModel,
@ -36,6 +37,8 @@ from ...image_utils import (
OPENAI_CLIP_STD,
PILImageResampling,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...processing_utils import Unpack
from ...utils import add_start_docstrings, can_return_tuple, is_torchdynamo_compiling, logging
@ -217,6 +220,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
return video_features
@can_return_tuple
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
def forward(
self,
@ -237,7 +241,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, LlavaOnevisionModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -328,10 +332,10 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = LlavaOnevisionModelOutputWithPast(
return LlavaOnevisionModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
@ -340,8 +344,6 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGeneration):
@can_return_tuple
@ -367,7 +369,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -452,7 +454,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -462,7 +464,9 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return LlavaOnevisionCausalLMOutputWithPast(
loss=loss,


@ -28,9 +28,12 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -233,6 +236,7 @@ class Mistral3PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
# important: this ported version of Mistral3 isn't meant for training from scratch - only
@ -251,6 +255,144 @@ class Mistral3PreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
@add_start_docstrings(
"""The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.""",
MISTRAL3_START_DOCSTRING,
)
class Mistral3Model(Mistral3PreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: Mistral3Config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = Mistral3MultiModalProjector(config)
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
image_sizes: torch.Tensor,
**kwargs,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_sizes (`torch.Tensor`):
Tensor containing the image sizes as returned by the processor.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
else:
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
return image_features
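
As the docstring above says, `vision_feature_layer` may be a single index or a list; in the list case the selected hidden states are concatenated along the feature dimension before projection. A toy illustration of just that selection step (made-up shapes):

import torch

# stand-in for image_outputs.hidden_states: 5 layers of (batch, patches, dim)
hidden_states = [torch.randn(1, 16, 8) for _ in range(5)]

vision_feature_layer = [-2, -1]
if isinstance(vision_feature_layer, int):
    selected_image_feature = hidden_states[vision_feature_layer]
else:
    hs_pool = [hidden_states[layer_idx] for layer_idx in vision_feature_layer]
    selected_image_feature = torch.cat(hs_pool, dim=-1)

assert selected_image_feature.shape == (1, 16, 16)   # 8 + 8 features
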
@can_return_tuple
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, Mistral3ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
return Mistral3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
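
The placeholder-replacement step above relies on `masked_scatter`: every position holding the image token is filled, in order, with the flattened image features. A toy example with made-up ids and shapes:

import torch

embed_dim, image_token_id = 4, 10
input_ids = torch.tensor([[1, 10, 10, 5]])            # two image placeholders
inputs_embeds = torch.zeros(1, 4, embed_dim)
image_features = torch.arange(2 * embed_dim, dtype=torch.float32).view(2, embed_dim)

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

assert torch.equal(inputs_embeds[0, 1], image_features[0])
assert torch.equal(inputs_embeds[0, 2], image_features[1])
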
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
MISTRAL3_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@ -328,142 +470,6 @@ MISTRAL3_INPUTS_DOCSTRING = r"""
"""
@add_start_docstrings(
"""The Mistral3 model which consists of a vision backbone and a language model, without a language modeling head.""",
MISTRAL3_START_DOCSTRING,
)
class Mistral3Model(Mistral3PreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: Mistral3Config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = Mistral3MultiModalProjector(config)
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
image_sizes: torch.Tensor,
**kwargs,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_sizes (`torch.Tensor`):
Tensor containing the image sizes as returned by the processor.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
else:
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
return image_features
@add_start_docstrings_to_model_forward(MISTRAL3_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, Mistral3ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = Mistral3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""The MISTRAL3 model which consists of a vision backbone and a language model.""",
MISTRAL3_START_DOCSTRING,
@ -526,8 +532,8 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor = None,
**lm_kwargs,
image_sizes: Optional[torch.Tensor] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -585,7 +591,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
return_dict=True,
cache_position=cache_position,
image_sizes=image_sizes,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -595,7 +601,9 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return Mistral3CausalLMOutputWithPast(
loss=loss,


@ -19,6 +19,8 @@ import torch
from torch import nn
from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -27,6 +29,7 @@ from ...utils import (
replace_return_docstrings,
)
from ..llava.modeling_llava import (
KwargsForCausalLM,
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
@ -174,6 +177,7 @@ class Mistral3Model(LlavaModel):
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
return image_features
@can_return_tuple
def forward(
self,
input_ids: torch.LongTensor = None,
@ -189,7 +193,7 @@ class Mistral3Model(LlavaModel):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, Mistral3ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -239,17 +243,16 @@ class Mistral3Model(LlavaModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = Mistral3ModelOutputWithPast(
return Mistral3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
@ -271,8 +274,8 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor = None,
**lm_kwargs,
image_sizes: Optional[torch.Tensor] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -330,7 +333,7 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
return_dict=True,
cache_position=cache_position,
image_sizes=image_sizes,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -340,7 +343,9 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return Mistral3CausalLMOutputWithPast(
loss=loss,


@ -15,21 +15,24 @@
"""PyTorch Mllama model."""
import math
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -180,13 +183,56 @@ class MllamaVisionMLP(nn.Module):
return hidden_states
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
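
`repeat_kv` is copied in verbatim so the vision and cross-attention modules no longer need per-backend subclasses; as its docstring states, it is just `torch.repeat_interleave` over the key/value head dimension. A quick sanity check with toy shapes:

import torch

batch, kv_heads, seq_len, head_dim, n_rep = 2, 4, 6, 8, 3
hidden_states = torch.randn(batch, kv_heads, seq_len, head_dim)

expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
expanded = expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)

assert torch.equal(expanded, torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1))
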
# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class MllamaVisionAttention(nn.Module):
def __init__(self, config: MllamaVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.attention_heads
self.head_dim = config.hidden_size // config.attention_heads
self.scaling = self.head_dim**-0.5
self.num_key_value_groups = 1
self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
@ -198,6 +244,7 @@ class MllamaVisionAttention(nn.Module):
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
query = self.q_proj(hidden_state)
key = self.k_proj(hidden_state)
@ -210,73 +257,35 @@ class MllamaVisionAttention(nn.Module):
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
attention_interface: Callable = eager_attention_forward
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value)
attn_output, attn_weights = attention_interface(
self,
query,
key,
value,
attention_mask,
dropout=0.0,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
output = self.o_proj(attn_output)
attn_output = attn_output.reshape(batch_size, q_seq_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return output, attn_weights
class MllamaVisionSdpaAttention(MllamaVisionAttention):
# Adapted from MllamaVisionAttention
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
) -> torch.Tensor:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
if output_attentions:
logger.warning_once(
"MllamaModel is using MllamaVisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_state=hidden_state,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
query = self.q_proj(hidden_state)
key = self.k_proj(hidden_state)
value = self.v_proj(hidden_state)
batch_size, q_seq_len, _ = query.shape
_, kv_seq_len, _ = key.shape
query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
output = self.o_proj(attn_output)
return output, None
MLLAMA_VISION_ATTENTION_CLASSES = {"eager": MllamaVisionAttention, "sdpa": MllamaVisionSdpaAttention}
return attn_output, attn_weights
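
With the `MllamaVisionSdpaAttention` subclass and the `MLLAMA_VISION_ATTENTION_CLASSES` mapping gone, backend selection now happens inside a single forward via `ALL_ATTENTION_FUNCTIONS`. A stripped-down sketch of that dispatch pattern (the registry below is a stand-in for illustration, not the real transformers object):

from typing import Callable, Dict

def eager_attention(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    ...  # reference implementation, always available as the fallback

# stand-in registry keyed by config._attn_implementation
ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "sdpa": lambda *args, **kwargs: ...,
    "flash_attention_2": lambda *args, **kwargs: ...,
}

def pick_attention(config, output_attentions: bool) -> Callable:
    attention_interface: Callable = eager_attention
    if config._attn_implementation != "eager":
        if config._attn_implementation == "sdpa" and output_attentions:
            # SDPA cannot return attention weights, so stay on the eager path
            pass
        else:
            attention_interface = ATTENTION_FUNCTIONS[config._attn_implementation]
    return attention_interface
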
class MllamaVisionEncoderLayer(nn.Module):
@ -288,7 +297,7 @@ class MllamaVisionEncoderLayer(nn.Module):
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size
self.self_attn = MLLAMA_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
self.self_attn = MllamaVisionAttention(config)
self.mlp = MllamaVisionMLP(config)
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
@ -453,6 +462,7 @@ class MllamaTextCrossAttention(nn.Module):
self.head_dim = config.hidden_size // self.num_heads
self.layer_idx = layer_idx
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@ -471,6 +481,7 @@ class MllamaTextCrossAttention(nn.Module):
output_attentions: bool = False,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, q_len, _ = hidden_states.size()
@ -503,17 +514,29 @@ class MllamaTextCrossAttention(nn.Module):
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
attention_interface: Callable = eager_attention_forward
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
@ -522,100 +545,6 @@ class MllamaTextCrossAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class MllamaTextCrossSdpaAttention(MllamaTextCrossAttention):
"""
Mllama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`MllamaTextCrossAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from MllamaTextCrossAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"MllamaModel is using MllamaTextCrossSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
cross_attention_states=cross_attention_states,
attention_mask=attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
query_states = self.q_norm(query_states)
if cross_attention_states is not None:
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
# if we have a new image + new tokens, we only computed key_states on that new image
# we still update the cross key states, past_image, new_image. And use it!
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_value.key_cache[self.layer_idx],
past_key_value.value_cache[self.layer_idx],
)
else:
raise ValueError(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
key_states = self.k_norm(key_states)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if attention_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@ -652,19 +581,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class MllamaTextSelfAttention(nn.Module):
def __init__(self, config: MllamaTextConfig, layer_idx: int):
super().__init__()
@ -675,6 +591,7 @@ class MllamaTextSelfAttention(nn.Module):
self.num_key_value_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.layer_idx = layer_idx
@ -712,23 +629,29 @@ class MllamaTextSelfAttention(nn.Module):
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attention_interface: Callable = eager_attention_forward
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
@ -737,92 +660,6 @@ class MllamaTextSelfAttention(nn.Module):
return attn_output, attn_weights, past_key_value
class MllamaTextSelfSdpaAttention(MllamaTextSelfAttention):
# Adapted from MllamaTextSelfAttention
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
output_attentions: bool = False,
use_cache: bool = False,
past_key_value=None,
cache_position=None,
**kwargs,
):
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"MllamaModel is using MllamaTextSelfSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES = {"eager": MllamaTextCrossAttention, "sdpa": MllamaTextCrossSdpaAttention}
MLLAMA_TEXT_ATTENTION_CLASSES = {"eager": MllamaTextSelfAttention, "sdpa": MllamaTextSelfSdpaAttention}
# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
def __init__(self, config):
@ -847,7 +684,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MLLAMA_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.self_attn = MllamaTextSelfAttention(config=config, layer_idx=layer_idx)
self.mlp = MllamaTextMLP(config)
self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -868,6 +705,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -905,6 +743,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module):
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@ -931,7 +770,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None:
super().__init__()
self.layer_idx = layer_idx
self.cross_attn = MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.cross_attn = MllamaTextCrossAttention(config, layer_idx=layer_idx)
self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))
@ -953,6 +792,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@ -964,6 +804,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
past_key_value=past_key_value,
output_attentions=output_attentions,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
@ -1026,7 +867,9 @@ class MllamaPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_static_cache = False # static cache cannot have different shapes for each layer
_supports_sdpa = True
_supports_flash_attn_2 = True
_supports_quantized_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -1694,6 +1537,7 @@ class MllamaTextModel(MllamaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
@ -1804,6 +1648,7 @@ class MllamaTextModel(MllamaPreTrainedModel):
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]
@ -1832,6 +1677,9 @@ class MllamaTextModel(MllamaPreTrainedModel):
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
"""The Mllama Text Model with a language modeling head on top.""",
MLLAMA_START_DOCSTRING,
@ -1888,7 +1736,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -1945,6 +1793,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
@ -1953,7 +1802,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
@ -1999,6 +1848,7 @@ class MllamaModel(MllamaPreTrainedModel):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
@can_return_tuple
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
def forward(
self,
@ -2017,6 +1867,7 @@ class MllamaModel(MllamaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -2079,15 +1930,15 @@ class MllamaModel(MllamaPreTrainedModel):
output_attentions=output_attentions,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
output = BaseModelOutputWithPast(
return BaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
@ -2153,7 +2004,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -2220,6 +2071,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
@ -2229,7 +2081,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **loss_kwargs)
loss = self.loss_function(logits, labels, self.config.text_config.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,
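With the dedicated eager/SDPA classes removed and `_supports_flash_attn_2` / `_supports_attention_backend` now set on `MllamaPreTrainedModel`, the backend is chosen once at load time through `attn_implementation`. A minimal sketch of that call, assuming the illustrative checkpoint id below:

import torch
from transformers import MllamaForConditionalGeneration

# Sketch only: the checkpoint id is illustrative; "eager", "sdpa" and
# "flash_attention_2" (if installed) are the backends enabled by the flags above.
model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
print(model.config.text_config._attn_implementation)  # expected to echo the requested backend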

View File

@ -236,7 +236,7 @@ class MoshiConfig(PretrainedConfig):
model_type = "moshi"
keys_to_ignore_at_inference = ["past_key_values"]
sub_configs = {"audio_encoder_config": AutoConfig}
sub_configs = {"audio_encoder_config": AutoConfig, "depth_decoder_config": MoshiDepthConfig}
def __init__(
self,

View File

@ -1907,7 +1907,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
self.audio_encoder = AutoModel.from_config(config.audio_encoder_config)
self.decoder = MoshiForCausalLM(config)
self.depth_decoder = MoshiDepthDecoder(config.depth_decoder_config)
self.depth_decoder = MoshiDepthDecoder._from_config(config.depth_decoder_config)
self.num_codebooks = config.num_codebooks
self.post_init()
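Swapping the bare constructor for `MoshiDepthDecoder._from_config(...)` is what lets the depth decoder pick up the same attention-implementation (and dtype) resolution as the outer model, since `_from_config` goes through `PreTrainedModel`'s config handling. A hedged sketch of building the sub-model standalone; the checkpoint id is illustrative and the import path is taken from this file:

from transformers import AutoConfig
from transformers.models.moshi.modeling_moshi import MoshiDepthDecoder

config = AutoConfig.from_pretrained("kyutai/moshiko-pytorch-bf16")  # illustrative checkpoint id
# `_from_config` applies the attention/dtype resolution that a plain
# `MoshiDepthDecoder(config.depth_decoder_config)` call would skip.
depth_decoder = MoshiDepthDecoder._from_config(config.depth_decoder_config)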

View File

@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch OPT model."""
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@ -27,18 +27,21 @@ from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
)
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
replace_return_docstrings,
@ -53,7 +56,7 @@ if is_torch_flex_attn_available():
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
pass
logger = logging.get_logger(__name__)
@ -98,6 +101,30 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
return super().forward(position_ids + self.offset)
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class OPTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -143,9 +170,8 @@ class OPTAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
# isn't needed in normal attention, but needed in flash attention so to keep the signature same
position_ids: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
@ -165,206 +191,35 @@ class OPTAttention(nn.Module):
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
attn_weights = torch.matmul(query_states, key_states.transpose(3, 2))
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attention_interface: Callable = eager_attention_forward
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_probs, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_probs, past_key_value
class OptFlashAttention2(OPTAttention):
"""
OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
The only required change would be on the forward pass where it needs to correctly call the public API of flash
attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
position_ids: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, query_length, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
attn_dropout = self.dropout if self.training else 0.0
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
position_ids=position_ids,
dropout=attn_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scaling,
**kwargs,
)
attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
attn_output = self.out_proj(attn_weights_reshaped)
if not output_attentions:
attn_weights_reshaped = None
return attn_output, attn_weights_reshaped, past_key_value
class OPTSdpaAttention(OPTAttention):
"""
OPT sdpa attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
The only required change would be on the forward pass where it needs to correctly call the public API of sdpa
attention and deal with padding tokens in case the input contains any of them.
"""
def forward(
self,
hidden_states: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
position_ids: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions or layer_head_mask is not None:
logger.warning_once(
"OPTModel is using SDPA attention, which currently does not support output_attentions=True."
'falling back to eager attention. Remove this warning using attn_implementation="eager".'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
if not output_attentions:
attn_weights = None
OPT_ATTENTION_CLASSES = {
"eager": OPTAttention,
"flash_attention_2": OptFlashAttention2,
"sdpa": OPTSdpaAttention,
}
return attn_output, attn_weights, past_key_value
class OPTDecoderLayer(nn.Module):
@ -372,7 +227,7 @@ class OPTDecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.self_attn = OPTAttention(config=config, layer_idx=layer_idx)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
@ -395,6 +250,7 @@ class OPTDecoderLayer(nn.Module):
use_cache: Optional[bool] = False,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -429,6 +285,7 @@ class OPTDecoderLayer(nn.Module):
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
cache_position=cache_position,
**kwargs,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@ -495,8 +352,10 @@ class OPTPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
_supports_attention_backend = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
@ -763,6 +622,7 @@ class OPTDecoder(OPTPreTrainedModel):
return causal_mask
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -776,6 +636,7 @@ class OPTDecoder(OPTPreTrainedModel):
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
@ -942,6 +803,7 @@ class OPTDecoder(OPTPreTrainedModel):
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]
@ -966,8 +828,6 @@ class OPTDecoder(OPTPreTrainedModel):
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@ -996,6 +856,7 @@ class OPTModel(OPTPreTrainedModel):
def get_decoder(self):
return self.decoder
@can_return_tuple
@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
@ -1016,6 +877,7 @@ class OPTModel(OPTPreTrainedModel):
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1035,13 +897,11 @@ class OPTModel(OPTPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
if not return_dict:
return decoder_outputs
return BaseModelOutputWithPast(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
@ -1050,6 +910,9 @@ class OPTModel(OPTPreTrainedModel):
)
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@ -1081,6 +944,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.model.decoder
@can_return_tuple
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
@ -1096,7 +960,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1198,8 +1062,9 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
logits = self.lm_head(outputs[0]).contiguous()
@ -1215,10 +1080,6 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
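OPT now follows the shared pattern used across this PR: a local `eager_attention_forward` as the reference path, plus a lookup in `ALL_ATTENTION_FUNCTIONS` keyed by `config._attn_implementation`. The toy module below is a self-contained sketch of that dispatch idea only; it is not the OPT code, and the small registry stands in for `ALL_ATTENTION_FUNCTIONS`:

import torch
from torch import nn

def eager_attention(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # reference implementation: explicit softmax(QK^T * scaling) @ V
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

def sdpa_attention(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # fused kernel path; the `scale=` argument needs torch>=2.1
    attn_output = nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None

ATTENTION_REGISTRY = {"eager": eager_attention, "sdpa": sdpa_attention}  # stand-in for ALL_ATTENTION_FUNCTIONS

class ToyAttention(nn.Module):
    def __init__(self, impl="sdpa"):
        super().__init__()
        self.impl = impl

    def forward(self, q, k, v, mask=None):
        attention_interface = ATTENTION_REGISTRY[self.impl]  # same lookup idea as in the diff
        out, _ = attention_interface(self, q, k, v, mask, scaling=q.shape[-1] ** -0.5)
        return out

q = k = v = torch.randn(1, 2, 5, 8)  # (batch, heads, seq_len, head_dim)
print(ToyAttention("eager")(q, k, v).shape)  # torch.Size([1, 5, 2, 8]) after the transpose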

View File

@ -23,9 +23,12 @@ from torch import nn
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@ -159,6 +162,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
_supports_static_cache = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_attention_backend = True
def _init_weights(self, module):
# important: this ported version of PaliGemmaisn't meant for training from scratch - only
@ -352,6 +356,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
image_features = image_features / (self.config.text_config.hidden_size**0.5)
return image_features
@can_return_tuple
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
def forward(
self,
@ -368,7 +373,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, PaligemmaModelOutputWithPast]:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@ -436,17 +441,19 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = PaligemmaModelOutputWithPast(
return PaligemmaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
@ -512,7 +519,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -570,7 +577,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -580,7 +587,9 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return PaliGemmaCausalLMOutputWithPast(
loss=loss,
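`KwargsForCausalLM(FlashAttentionKwargs, LossKwargs)` is simply two TypedDicts merged so that a single `**kwargs` annotation carries both attention arguments (forwarded to the backend) and loss arguments (forwarded to `self.loss_function`, e.g. token counts for gradient accumulation). A hedged, self-contained sketch of that merging pattern; the field names below are assumptions modelled on the imported types, not taken from this diff:

from typing import Optional
from typing_extensions import TypedDict, Unpack

class AttnKwargsSketch(TypedDict, total=False):
    cu_seq_lens_q: Optional[list]          # assumed field, mirroring FlashAttentionKwargs

class LossKwargsSketch(TypedDict, total=False):
    num_items_in_batch: Optional[int]      # assumed field, mirroring LossKwargs

class KwargsForCausalLMSketch(AttnKwargsSketch, LossKwargsSketch): ...  # same trick as in the diff

def forward(**kwargs: Unpack[KwargsForCausalLMSketch]) -> None:
    # attention keys would flow to the attention backend, loss keys to loss_function
    print(sorted(kwargs))

forward(num_items_in_batch=4)  # prints ['num_items_in_batch']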

View File

@ -21,14 +21,18 @@ import torch
import torch.utils.checkpoint
from torch import nn
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput
from ...modeling_rope_utils import dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
)
from .configuration_pixtral import PixtralVisionConfig
@ -132,7 +136,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
# Copied from transformers.models.smolvlm.modeling_smolvlm.eager_attention_forward
# Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
@ -167,10 +171,11 @@ class PixtralAttention(nn.Module):
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.is_causal = False
self.scale = self.head_dim**-0.5
self.scaling = self.head_dim**-0.5
self.is_causal = False
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@ -211,28 +216,22 @@ class PixtralAttention(nn.Module):
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
# Since we use packing, if Flash-Attn 2 is selected we rely on position_ids
if self.config._attn_implementation == "flash_attention_2":
kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True)
attention_mask = None
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
is_causal=self.is_causal,
output_attentions=output_attentions,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(batch_size, patches, -1)
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
@ -288,7 +287,7 @@ class PixtralAttentionLayer(nn.Module):
attention_mask: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: Optional[bool] = None,
**kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor]:
"""
Args:
@ -341,7 +340,7 @@ class PixtralTransformer(nn.Module):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
@ -383,7 +382,6 @@ class PixtralTransformer(nn.Module):
attention_mask,
position_embeddings,
output_attentions,
**kwargs,
)
else:
layer_outputs = encoder_layer(
@ -431,6 +429,10 @@ class PixtralPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_attention_backend = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_no_split_modules = ["PixtralAttentionLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
@ -508,6 +510,7 @@ class PixtralVisionModel(PixtralPreTrainedModel):
def get_input_embeddings(self):
return self.patch_conv
@can_return_tuple
@add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
def forward(
self,
@ -517,7 +520,7 @@ class PixtralVisionModel(PixtralPreTrainedModel):
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args,
**kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutput]:
"""
Returns:
@ -551,17 +554,15 @@ class PixtralVisionModel(PixtralPreTrainedModel):
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
)
out = self.transformer(
return self.transformer(
patch_embeds,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
return out
__all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"]
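The `@can_return_tuple` decorator replaces the trailing `return output if return_dict else output.to_tuple()` boilerplate: `forward` always builds a ModelOutput, and the wrapper downgrades it to a tuple when a tuple is requested. A toy re-implementation of the idea (not the actual decorator, which also honours config defaults):

import functools
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class TinyOutput:
    last_hidden_state: Tuple[int, ...]
    attentions: Optional[Tuple[int, ...]] = None

    def to_tuple(self):
        return tuple(v for v in (self.last_hidden_state, self.attentions) if v is not None)

def can_return_tuple_sketch(forward):
    @functools.wraps(forward)
    def wrapper(*args, return_dict: bool = True, **kwargs):
        output = forward(*args, **kwargs)
        return output if return_dict else output.to_tuple()
    return wrapper

@can_return_tuple_sketch
def forward(x):
    return TinyOutput(last_hidden_state=x)

print(forward((1, 2, 3)))                     # TinyOutput(last_hidden_state=(1, 2, 3), attentions=None)
print(forward((1, 2, 3), return_dict=False))  # ((1, 2, 3),)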

View File

@ -24,17 +24,20 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...cache_utils import DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
)
@ -78,6 +81,7 @@ class SmolVLMPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_attention_backend = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -574,80 +578,6 @@ class SmolVLMConnector(nn.Module):
return image_hidden_states
SMOLVLM_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
[`CLIPImageProcessor`] for processing images).
pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
Mask to avoid performing attention on padding pixel indices.
image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The hidden states of the image encoder after modality projection.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"""SmolVLM model consisting of a SIGLIP vision encoder and Llama3 language decoder""",
SMOLVLM_START_DOCSTRING,
@ -746,18 +676,7 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
return merged_embeds
@add_start_docstrings_to_model_forward(
"""
Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
max_num_images is the maximum number of images among the batch_size samples in the batch.
Padding images are not needed beyond padding the pixel_values at the entrance of the model.
For efficiency, we only pass through the vision_model's forward the real images by
discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
""",
SMOLVLM_INPUTS_DOCSTRING,
)
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -773,6 +692,7 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, SmolVLMBaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -873,13 +793,11 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
if not return_dict:
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
return SmolVLMBaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
@ -927,6 +845,83 @@ class SmolVLMCausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
SMOLVLM_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
[`CLIPImageProcessor`] for processing images).
pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
Mask to avoid performing attention on padding pixel indices.
image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The hidden states of the image encoder after modality projection.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"""The SmolVLM Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
SMOLVLM_START_DOCSTRING,
@ -979,6 +974,7 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@can_return_tuple
@add_start_docstrings_to_model_forward(SMOLVLM_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SmolVLMCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -998,6 +994,7 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, SmolVLMCausalLMOutputWithPast]:
r"""
Args:
@ -1066,7 +1063,8 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
hidden_states = outputs[0]
@ -1076,26 +1074,9 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return SmolVLMCausalLMOutputWithPast(
loss=loss,
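The hand-written shift-and-cross-entropy block above is replaced by `self.loss_function(...)`, which also receives the extra `**kwargs`. A hedged sketch of the shifted causal-LM loss that helper encapsulates (the real helper additionally handles `num_items_in_batch` and attention-mask details):

import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # position t predicts token t+1, so drop the last logit and the first label
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, vocab_size).float(),
        shift_labels.view(-1),
        ignore_index=-100,
    )

logits = torch.randn(2, 6, 32)             # (batch, seq_len, vocab)
labels = torch.randint(0, 32, (2, 6))
print(causal_lm_loss(logits, labels, vocab_size=32))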

View File

@ -20,7 +20,10 @@ import torch.utils.checkpoint
from torch import nn
from ...cache_utils import DynamicCache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...processing_utils import Unpack
from ...utils import (
can_return_tuple,
logging,
)
from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
@ -195,6 +198,7 @@ class SmolVLMModel(Idefics3Model):
merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
return merged_embeds
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -210,6 +214,7 @@ class SmolVLMModel(Idefics3Model):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, SmolVLMBaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -310,13 +315,11 @@ class SmolVLMModel(Idefics3Model):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
if not return_dict:
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
return SmolVLMBaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,

View File

@ -23,9 +23,12 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
LossKwargs,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -181,6 +184,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = (
@ -387,6 +391,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
return video_features, num_frames
@can_return_tuple
@add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
def forward(
self,
@ -404,7 +409,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, VideoLlavaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -475,10 +480,10 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
output = VideoLlavaModelOutputWithPast(
return VideoLlavaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
@ -486,7 +491,9 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
image_hidden_states=image_features if pixel_values_images is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@add_start_docstrings(
@ -559,7 +566,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -671,7 +678,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
**kwargs,
)
hidden_states = outputs[0]
@ -681,7 +688,9 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return VideoLlavaCausalLMOutputWithPast(
loss=loss,

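Renaming `**lm_kwargs` to `**kwargs: Unpack[FlashAttentionKwargs]` (and `Unpack[KwargsForCausalLM]` on the generation head) gives the forwarded keyword arguments a typed contract. A trimmed-down, hypothetical sketch of the pattern follows; the key names are illustrative stand-ins, not the exact fields of the real TypedDicts:

from typing import Optional
from typing_extensions import TypedDict, Unpack

class FlashAttentionKwargs(TypedDict, total=False):
    max_length_q: Optional[int]  # illustrative attention-backend flag
    max_length_k: Optional[int]

class LossKwargs(TypedDict, total=False):
    num_items_in_batch: Optional[int]

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...

def attention_forward(**kwargs: Unpack[FlashAttentionKwargs]) -> None:
    print("attention kwargs:", kwargs)

def causal_lm_forward(**kwargs: Unpack[KwargsForCausalLM]) -> None:
    # The head accepts the union of loss and attention kwargs and forwards only
    # the attention-related ones to the inner model.
    attn_keys = FlashAttentionKwargs.__annotations__
    attention_forward(**{k: v for k, v in kwargs.items() if k in attn_keys})

causal_lm_forward(max_length_q=128, num_items_in_batch=8)
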
View File

@ -171,6 +171,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
# important: this ported version of VipLlava isn't meant for training from scratch - only

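Setting `_supports_attention_backend = True` opts the model into the generic attention-dispatch interface, so the implementation can be chosen at load time. A usage sketch is shown below; `attn_implementation` is a real `from_pretrained` argument, while the checkpoint name is only an illustrative public VipLlava checkpoint and the call downloads weights:

import torch
from transformers import VipLlavaForConditionalGeneration

# attn_implementation can be "eager", "sdpa", or "flash_attention_2" (if installed).
model = VipLlavaForConditionalGeneration.from_pretrained(
    "llava-hf/vip-llava-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
print(model.config._attn_implementation)
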
View File

@ -461,6 +461,7 @@ class Blip2ForConditionalGenerationDecoderOnlyModelTester:
@require_torch
class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
additional_model_inputs = ["input_ids"]
fx_compatible = False
test_head_masking = False
test_pruning = False
@ -526,15 +527,11 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
self.assertTrue(model.language_model.config._attn_implementation == "sdpa")
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model.qformer.config._attn_implementation == "eager")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
@ -545,20 +542,13 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -869,6 +859,7 @@ class Blip2ModelTester:
@require_torch
class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
additional_model_inputs = ["input_ids", "decoder_input_ids"]
# Doesn't run generation tests. TODO: fix generation tests for Blip2ForConditionalGeneration
all_generative_model_classes = ()
pipeline_model_mapping = (
@ -967,15 +958,11 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
self.assertTrue(model.language_model.config._attn_implementation == "eager")
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model.qformer.config._attn_implementation == "eager")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
@ -986,20 +973,13 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -1485,6 +1465,7 @@ class Blip2TextRetrievalModelTester:
@require_torch
class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Blip2ForImageTextRetrieval,) if is_torch_available() else ()
additional_model_inputs = ["input_ids"]
fx_compatible = False
test_head_masking = False
test_pruning = False

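The rewritten assertions rely on each sub-model's `config._attn_implementation` rather than on `*SdpaAttention` class names, which no longer exist after the attention refactor. A sketch of the same inspection outside the test suite; the checkpoint name is an illustrative public BLIP-2 checkpoint and `_attn_implementation` is a private attribute:

from transformers import Blip2ForConditionalGeneration

model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
# Each sub-model records the implementation it was dispatched to (see the
# assertions above for the values the tests expect).
print(model.language_model.config._attn_implementation)
print(model.vision_model.config._attn_implementation)
print(model.qformer.config._attn_implementation)
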
View File

@ -475,6 +475,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
else ()
)
pipeline_model_mapping = {"image-text-to-text": InstructBlipForConditionalGeneration}
additional_model_inputs = ["qformer_input_ids", "input_ids"]
fx_compatible = False
test_head_masking = False
test_pruning = False
@ -687,15 +688,11 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
self.assertTrue(model.language_model.config._attn_implementation == "sdpa")
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model.qformer.config._attn_implementation == "eager")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
@ -706,20 +703,13 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
# We will verify our results on an image of cute cats
def prepare_img():

View File

@ -492,6 +492,7 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
all_model_classes = (
(InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel) if is_torch_available() else ()
)
additional_model_inputs = ["qformer_input_ids", "input_ids"]
fx_compatible = False
test_head_masking = False
test_pruning = False
@ -702,15 +703,11 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
text_attn = "sdpa" if model.language_model._supports_sdpa else "eager"
vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager"
qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager"
# `None` as it is the requested one which will be assigned to each sub-config
# Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present)
self.assertTrue(model.language_model.config._attn_implementation == text_attn)
self.assertTrue(model.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model.qformer.config._attn_implementation == qformer_attn)
self.assertTrue(model.language_model.config._attn_implementation == "sdpa")
self.assertTrue(model.vision_model.config._attn_implementation == "sdpa")
self.assertTrue(model.qformer.config._attn_implementation == "eager")
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
@ -721,20 +718,13 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(
module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn]
):
raise ValueError("The SDPA model should have SDPA attention layers")
# We will verify our results on an image of cute cats
def prepare_video():

View File

@ -30,6 +30,7 @@ from transformers.testing_utils import (
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@ -42,6 +43,7 @@ from transformers.utils import (
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
_config_zero_init,
floats_tensor,
@ -259,6 +261,7 @@ class Kosmos2ModelTester:
@require_torch
class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Kosmos2Model, Kosmos2ForConditionalGeneration) if is_torch_available() else ()
additional_model_inputs = ["input_ids", "image_embeds_position_mask"]
pipeline_model_mapping = (
{
"feature-extraction": Kosmos2Model,
@ -462,6 +465,14 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_from_inputs_embeds(self):
pass
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
@unittest.skip("KOSMOS-2 doesn't support padding")
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@pytest.mark.generate
def test_left_padding_compatibility(self):
        # Overwrite because Kosmos-2 needs to pad pixel values and the image attention mask

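The new override skips the shared eager-vs-SDPA inference test for every entry in the parameter grid, since KOSMOS-2 cannot pad its inputs the way the common test requires. A minimal, self-contained illustration of the expand-then-skip pattern; the parameter grid below is a stand-in for `TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION`:

import unittest
from parameterized import parameterized

PARAMS = [("fp32", "left"), ("fp16", "right")]  # stand-in grid

class ExampleModelTest(unittest.TestCase):
    @parameterized.expand(PARAMS)
    @unittest.skip("this model doesn't support padding")
    def test_eager_matches_sdpa_inference(self, torch_dtype, padding_side):
        pass  # body never runs; each expanded case is reported as skipped

if __name__ == "__main__":
    unittest.main()
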
View File

@ -219,9 +219,10 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
else {}
)
is_encoder_decoder = False
fx_compatible = True
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
test_pruning = False
test_missing_keys = False
test_head_masking = False # new attn API doesn't support head mask
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(

View File

@ -109,6 +109,7 @@ class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
"""
all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
additional_model_inputs = ["image_sizes"]
test_pruning = False
test_head_masking = False
test_torchscript = False

View File

@ -3765,6 +3765,10 @@ class ModelTesterMixin:
key = "decoder_hidden_states"
elif "logits" in outputs_eager and "Classification" in model_class.__name__:
key = "logits"
elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower():
outputs_eager = outputs_eager["language_model_outputs"]
outputs_sdpa = outputs_sdpa["language_model_outputs"]
key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states"
else:
key = "hidden_states"
@ -4002,14 +4006,14 @@ class ModelTesterMixin:
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
sub_models_supporting_fa2 = [
module._supports_flash_attn_2
(module._supports_flash_attn_2 or module._supports_attention_backend)
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_fa2_all_modules = (
all(sub_models_supporting_fa2)
if len(sub_models_supporting_fa2) > 0
else model._supports_flash_attn_2
else (model._supports_flash_attn_2 or model._supports_attention_backend)
)
if not supports_fa2_all_modules:
with self.assertRaises(ValueError):
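The updated check treats `_supports_attention_backend` as equivalent to native FA2 support when deciding whether a composite model can be loaded with flash-attention 2. A hedged sketch of that predicate as a standalone helper, assuming the private support flags behave as in the hunk above:

from transformers import PreTrainedModel

def supports_fa2_everywhere(model: PreTrainedModel) -> bool:
    # A composite model qualifies only if every PreTrainedModel submodule either
    # supports FA2 natively or opts into the generic attention-backend interface.
    sub_flags = [
        (module._supports_flash_attn_2 or module._supports_attention_backend)
        for name, module in model.named_modules()
        if isinstance(module, PreTrainedModel) and name != ""
    ]
    if sub_flags:
        return all(sub_flags)
    # No nested PreTrainedModel: fall back to the top-level flags.
    return model._supports_flash_attn_2 or model._supports_attention_backend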