Merge branch 'main' into feature/add-emergency-checkpointing

Ayush Sharma 2025-07-01 09:51:14 -05:00 committed by GitHub
commit ad88a9e60e
132 changed files with 3038 additions and 4645 deletions

View File

@@ -733,7 +733,9 @@ class GenerationMixin(ContinuousMixin):
         # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
         # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
         if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
-            if not self.config.is_encoder_decoder:
+            if model_kwargs["inputs_embeds"] is None:
+                model_kwargs.pop("inputs_embeds")
+            elif not self.config.is_encoder_decoder:
                 has_inputs_embeds_forwarding = "inputs_embeds" in set(
                     inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
                 )
@@ -748,6 +750,7 @@ class GenerationMixin(ContinuousMixin):
                 model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
                     inputs, bos_token_id, model_kwargs=model_kwargs
                 )
+                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
             else:
                 if inputs is not None:
                     raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")

View File

@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations

+import math
 import operator
 import os
 import re
@@ -280,7 +281,48 @@ def repack_weights(
 def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
     """
     Generalized tensor sharding across a multi-dimensional device mesh.
+    Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along `dim`.
+    Extraction follows the PyTorch `Shard` placement, so that sharding and materializing back to the full tensor follow `Shard` semantics.
+    `Shard` uses torch.chunk-style sharding of the tensor. We demonstrate below how sharding happens, including edge cases
+    such as some ranks receiving an empty tensor as their shard. The implementation below is robust to all these cases.
+
+    Case (1)
+    empty_param             (16, 5120, 8190)
+    dim                     0
+    device_mesh.size()      4
+    rank 0 gets (4, 5120, 8190)   (0  ... 4,  5120, 8190)
+    rank 1 gets (4, 5120, 8190)   (4  ... 8,  5120, 8190)
+    rank 2 gets (4, 5120, 8190)   (8  ... 12, 5120, 8190)
+    rank 3 gets (4, 5120, 8190)   (12 ... 16, 5120, 8190)
+
+    Case (2)
+    empty_param             (16, 5120, 8190)
+    dim                     0
+    device_mesh.size()      14
+    rank 0  gets (2, 5120, 8190)  (0  ... 2,  5120, 8190)
+    rank 1  gets (2, 5120, 8190)  (2  ... 4,  5120, 8190)
+    rank 2  gets (2, 5120, 8190)  (4  ... 6,  5120, 8190)
+    rank 3  gets (2, 5120, 8190)  (6  ... 8,  5120, 8190)
+    rank 4  gets (2, 5120, 8190)  (8  ... 10, 5120, 8190)
+    rank 5  gets (2, 5120, 8190)  (10 ... 12, 5120, 8190)
+    rank 6  gets (2, 5120, 8190)  (12 ... 14, 5120, 8190)
+    rank 7  gets (2, 5120, 8190)  (14 ... 16, 5120, 8190)
+    rank 8  gets (0, 5120, 8190)
+    rank 9  gets (0, 5120, 8190)
+    rank 10 gets (0, 5120, 8190)
+    rank 11 gets (0, 5120, 8190)
+    rank 12 gets (0, 5120, 8190)
+    rank 13 gets (0, 5120, 8190)
+
+    Case (3)
+    empty_param             (16, 5120, 8190)
+    dim                     0
+    device_mesh.size()      3
+    rank 0 gets (6, 5120, 8190)   (0  ... 6,  5120, 8190)
+    rank 1 gets (6, 5120, 8190)   (6  ... 12, 5120, 8190)
+    rank 2 gets (4, 5120, 8190)   (12 ... 16, 5120, 8190)
+
+    In case (2), empty shards are returned with the appropriate dimensions so that downstream operations still work smoothly.
+
     Args:
         param (torch.Tensor): The tensor to shard.
         empty_param (torch.Tensor): A tensor used for shape reference.
@@ -289,6 +331,7 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
         dim (int): Dimension along which to shard the tensor.
     """
     param_dim = empty_param.dim()
     if dim < 0:
         dim = param_dim + dim
     if dim >= param_dim:
@@ -301,15 +344,18 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
     if rank >= world_size:
         raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")

-    shard_size = empty_param.shape[dim] // world_size
+    shard_size = math.ceil(empty_param.shape[dim] / world_size)
     start = rank * shard_size
-    end = start + shard_size

     # Construct slicing index dynamically
+    end = min(start + shard_size, empty_param.shape[dim])
     slice_indices = [slice(None)] * param_dim
-    slice_indices[dim] = slice(start, end)
-    return param[tuple(slice_indices)]
+    if start < empty_param.shape[dim]:
+        slice_indices[dim] = slice(start, end)
+        return param[tuple(slice_indices)]
+    dimensions = list(param.shape)
+    dimensions[dim] = 0
+    return torch.empty(tuple(dimensions), dtype=torch.int64)
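The new ceil-based slicing above matches torch.chunk semantics, including empty trailing shards. A standalone sketch follows (the helper name `chunk_shard_bounds` is illustrative, not library code):

import math
import torch

def chunk_shard_bounds(full_dim: int, world_size: int, rank: int) -> tuple[int, int]:
    # Same arithmetic as the hunk above: ceil-sized chunks, clamped to the dimension size.
    shard_size = math.ceil(full_dim / world_size)
    start = rank * shard_size
    end = min(start + shard_size, full_dim)
    return (start, end) if start < full_dim else (0, 0)

# Mirrors Case (2) of the docstring: 16 rows over 14 ranks -> ranks 0-7 own 2 rows, ranks 8-13 own 0.
print([chunk_shard_bounds(16, 14, r) for r in range(14)])
# torch.chunk agrees: it also produces 8 chunks of 2 rows for this shape.
print(len(torch.chunk(torch.empty(16, 4), 14, dim=0)))  # 8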
 def distribute_module(
@@ -500,7 +546,9 @@ class ColwiseParallel(TensorParallelLayer):
         if to_contiguous:
             parameter = parameter.contiguous()
         if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

     @staticmethod
@@ -574,7 +622,9 @@ class RowwiseParallel(TensorParallelLayer):
         if to_contiguous:
             parameter = parameter.contiguous()
         if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

     @staticmethod
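Both hunks pass the global shape and stride of the unsharded parameter to `DTensor.from_local`; with uneven shards the global shape can no longer be inferred as `local_dim * world_size`. A single-rank sketch of the call pattern (assumes PyTorch >= 2.4 with the public `torch.distributed.tensor` API and the gloo backend available):

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

empty_param = torch.empty(16, 8)   # full parameter, used only for its shape/stride
local_shard = torch.randn(16, 8)   # this rank's shard (the whole tensor on a 1-rank mesh)

# Same call shape as the diff: the explicit global shape/stride keeps the DTensor metadata
# correct even when ranks hold unevenly sized (or empty) local shards on a real multi-rank mesh.
param = DTensor.from_local(
    local_shard, mesh, [Shard(0)], run_check=False,
    shape=empty_param.size(), stride=empty_param.stride(),
)
print(param.shape)  # torch.Size([16, 8])

dist.destroy_process_group()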

View File

@@ -508,6 +508,22 @@ def _flash_attention_forward(
             query_states, key_states, value_states, target_dtype
         )

+    # We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach
+    # under two cases:
+    # Case 1. `position_ids` is provided and not all examples contain exactly one sequence: if the tensor is
+    #         monotonically increasing we probably have a single sequence, otherwise it is packed. We additionally
+    #         check that we are in the pre-fill/training stage.
+    # Case 2. Some models pass pre-computed `cu_seqlens` directly, so we don't need to infer them from position ids.
+    #         It is safe to use `flash_attn_varlen_func` since we already have all the necessary kwargs. NOTE: it is
+    #         the user's responsibility to flatten `position_ids` if the model requires it. See #39121 for more information.
+    is_fa2_with_position_ids = (
+        position_ids is not None
+        and query_states.shape[0] == 1
+        and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
+    )
+    is_fa2_with_varlen_kwargs = all(
+        kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
+    )
+
     # Contains at least one padding token in the sequence
     if attention_mask is not None:
         batch_size = query_states.shape[0]
@@ -531,14 +547,7 @@ def _flash_attention_forward(
         )
         attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)

-    # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
-    # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
-    # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif (
-        position_ids is not None
-        and query_states.shape[0] == 1
-        and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
-    ):
+    elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids:
         batch_size = query_states.size(0)

         if cu_seq_lens_q is None or cu_seq_lens_k is None:
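For reference, the standalone sketch below (not library code; the cu_seqlens derivation is illustrative) shows how the packed-sequence condition in `is_fa2_with_position_ids` relates to the varlen kwargs of Case 2:

import torch

# Two sequences of lengths 3 and 2 packed into a single row of length 5.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])

# A restart (negative diff) means the batch is packed, which is what the condition above detects.
is_packed = not (torch.diff(position_ids, dim=-1) >= 0).all()

# The corresponding varlen kwargs: cu_seqlens marks sequence boundaries, max_seqlen the longest one.
starts = torch.where(position_ids[0] == 0)[0]
seq_lens = torch.diff(torch.cat([starts, torch.tensor([position_ids.shape[-1]])]))
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)
print(is_packed, cu_seqlens.tolist(), int(seq_lens.max()))  # True [0, 3, 5] 3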

View File

@@ -3746,7 +3746,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
                 module_map[name + f".{key}"] = module
         state_dict = model_to_save.state_dict()

-        if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
+        if any(
+            allowed_name in class_name.__name__.lower()
+            for class_name in self.__class__.__mro__[:-1]
+            for allowed_name in VLMS
+        ):
             reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
             original_state_dict = {}
@@ -4402,7 +4406,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
         key_mapping = kwargs.pop("key_mapping", None)
         # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
-        if key_mapping is None and any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
+        if key_mapping is None and any(
+            allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
+        ):
             key_mapping = cls._checkpoint_conversion_mapping

         # Not used anymore -- remove them from the kwargs
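A tiny sketch (class names are illustrative) of why both checks now walk the MRO rather than just the class name: a user subclass whose own name lacks the VLM identifier still inherits from a class whose name has it.

VLMS = ["llava"]  # stand-in for the module-level constant

class LlavaForConditionalGeneration:  # stand-in for the real transformers class
    pass

class MyFineTunedModel(LlavaForConditionalGeneration):
    pass

cls = MyFineTunedModel
by_name = any(name in cls.__name__.lower() for name in VLMS)                         # False
by_mro = any(name in c.__name__.lower() for c in cls.__mro__[:-1] for name in VLMS)  # True
print(by_name, by_mro)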
@@ -5837,7 +5843,12 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
         else None
     )
     total_byte_count = defaultdict(lambda: 0)
+    tied_param_names = _get_tied_weight_keys(model)
     for param_name, device in accelerator_device_map.items():
+        # Skip if the parameter has already been accounted for (tied weights)
+        if param_name in tied_param_names:
+            continue
+
         param = model.get_parameter_or_buffer(param_name)
         # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
         param_byte_count = param.numel() * param.element_size()
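Why the tied-weight skip matters for the warmup byte count, as a standalone sketch with a toy module (not the real warmup code):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # tied: both names point at the same storage

model = Toy()
params = list(model.named_parameters(remove_duplicate=False))
naive_bytes = sum(p.numel() * p.element_size() for _, p in params)
unique_bytes = sum(p.numel() * p.element_size() for p in {p.data_ptr(): p for _, p in params}.values())
print(naive_bytes, unique_bytes)  # 320 vs 160: counting tied weights twice overestimates the warmup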

View File

@@ -16,7 +16,7 @@
 import math
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.utils.checkpoint
@@ -25,14 +25,15 @@ from torch import nn
 from ...activations import ACT2FN
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
+    BaseModelOutput,
     BaseModelOutputWithNoAttention,
-    BaseModelOutputWithPastAndCrossAttentions,
-    BaseModelOutputWithPoolingAndCrossAttentions,
+    BaseModelOutputWithPooling,
     BaseModelOutputWithPoolingAndNoAttention,
 )
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import ModelOutput, auto_docstring, logging
+from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
 from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig
@@ -90,7 +91,7 @@ class AlignOutput(ModelOutput):
             The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`].
         image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
             The output of [`AlignVisionModel`].
-        text_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`):
+        text_model_output (`BaseModelOutputWithPooling`):
             The output of the [`AlignTextModel`].
         vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`):
             The output of the [`AlignVisionModel`].
@@ -101,7 +102,7 @@ class AlignOutput(ModelOutput):
     logits_per_text: Optional[torch.FloatTensor] = None
     text_embeds: Optional[torch.FloatTensor] = None
     image_embeds: Optional[torch.FloatTensor] = None
-    text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
+    text_model_output: BaseModelOutputWithPooling = None
     vision_model_output: BaseModelOutputWithPoolingAndNoAttention = None

     def to_tuple(self) -> tuple[Any]:
@@ -508,7 +509,6 @@ class AlignVisionEncoder(nn.Module):
         )

-# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->AlignText
 class AlignTextEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""
@@ -537,7 +537,6 @@ class AlignTextEmbeddings(nn.Module):
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        past_key_values_length: int = 0,
     ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -547,7 +546,7 @@ class AlignTextEmbeddings(nn.Module):
         seq_length = input_shape[1]

         if position_ids is None:
-            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+            position_ids = self.position_ids[:, :seq_length]

         # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
         # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@@ -573,9 +572,35 @@ class AlignTextEmbeddings(nn.Module):
         return embeddings

-# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->AlignText
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class AlignTextSelfAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
+    def __init__(self, config):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -583,6 +608,7 @@ class AlignTextSelfAttention(nn.Module):
                 f"heads ({config.num_attention_heads})"
             )

+        self.config = config
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -592,20 +618,12 @@ class AlignTextSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = position_embedding_type or getattr(
-            config, "position_embedding_type", "absolute"
-        )
-        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            self.max_position_embeddings = config.max_position_embeddings
-            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
-
-        self.is_decoder = config.is_decoder
-
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
+        self.attention_dropout = config.attention_probs_dropout_prob
+        self.scaling = self.attention_head_size**-0.5

+    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
+    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
+    @deprecate_kwarg("past_key_value", version="4.54.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -615,96 +633,33 @@ class AlignTextSelfAttention(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)
-
-        # If this is instantiated as a cross-attention module, the keys
-        # and values come from an encoder; the attention mask needs to be
-        # such that the encoder's padding tokens are not attended to.
-        is_cross_attention = encoder_hidden_states is not None
-
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_layer = past_key_value[0]
-            value_layer = past_key_value[1]
-            attention_mask = encoder_attention_mask
-        elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
-            attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
-        else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        use_cache = past_key_value is not None
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_layer, value_layer)
-
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-
-        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
-            if use_cache:
-                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-                    -1, 1
-                )
-            else:
-                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
-            distance = position_ids_l - position_ids_r
-
-            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
-            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
-
-            if self.position_embedding_type == "relative_key":
-                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
-                attention_scores = attention_scores + relative_position_scores
-            elif self.position_embedding_type == "relative_key_query":
-                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
-                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
-                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
-
-        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
-        if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in AlignTextModel forward() function)
-            attention_scores = attention_scores + attention_mask
-
-        # Normalize the attention scores to probabilities.
-        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.dropout(attention_probs)
-
-        # Mask heads if we want to
-        if head_mask is not None:
-            attention_probs = attention_probs * head_mask
-
-        context_layer = torch.matmul(attention_probs, value_layer)
-
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(new_context_layer_shape)
-
-        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
-
-        if self.is_decoder:
-            outputs = outputs + (past_key_value,)
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.attention_head_size)
+
+        query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            head_mask=head_mask,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
         return outputs
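A quick shape check of the new attention path, assuming a transformers build that already contains this refactor (otherwise the import fails); tensors are random and follow the (batch, heads, seq, head_dim) layout used above:

import torch
from transformers.models.align.modeling_align import eager_attention_forward

batch, heads, seq, head_dim = 2, 4, 6, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# nn.Identity() only provides the `.training` flag that the dropout call reads.
attn_output, attn_weights = eager_attention_forward(
    torch.nn.Identity(), q, k, v, attention_mask=None, scaling=head_dim**-0.5, dropout=0.0
)
print(attn_output.shape)   # torch.Size([2, 6, 4, 8]) -- (batch, seq, heads, head_dim) after the transpose
print(attn_weights.shape)  # torch.Size([2, 4, 6, 6])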
@@ -723,18 +678,10 @@ class AlignTextSelfOutput(nn.Module):
         return hidden_states

-ALIGN_TEXT_SELF_ATTENTION_CLASSES = {
-    "eager": AlignTextSelfAttention,
-}
-
-
-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT
 class AlignTextAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
+    def __init__(self, config):
         super().__init__()
-        self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
-            config, position_embedding_type=position_embedding_type
-        )
+        self.self = AlignTextSelfAttention(config)
         self.output = AlignTextSelfOutput(config)
         self.pruned_heads = set()
@@ -756,6 +703,9 @@ class AlignTextAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)

+    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
+    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
+    @deprecate_kwarg("past_key_value", version="4.54.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -765,15 +715,14 @@ class AlignTextAttention(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> tuple[torch.Tensor]:
         self_outputs = self.self(
             hidden_states,
-            attention_mask,
-            head_mask,
-            encoder_hidden_states,
-            encoder_attention_mask,
-            past_key_value,
-            output_attentions,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            **kwargs,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -811,22 +760,18 @@ class AlignTextOutput(nn.Module):
         return hidden_states

-# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText
 class AlignTextLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
         self.attention = AlignTextAttention(config)
-        self.is_decoder = config.is_decoder
-        self.add_cross_attention = config.add_cross_attention
-        if self.add_cross_attention:
-            if not self.is_decoder:
-                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = AlignTextAttention(config, position_embedding_type="absolute")
         self.intermediate = AlignTextIntermediate(config)
         self.output = AlignTextOutput(config)

+    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
+    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
+    @deprecate_kwarg("past_key_value", version="4.54.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -836,60 +781,23 @@ class AlignTextLayer(GradientCheckpointingLayer):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> tuple[torch.Tensor]:
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
             hidden_states,
-            attention_mask,
-            head_mask,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
             output_attentions=output_attentions,
-            past_key_value=self_attn_past_key_value,
+            **kwargs,
         )
         attention_output = self_attention_outputs[0]
-
-        # if decoder, the last output is tuple of self-attn cache
-        if self.is_decoder:
-            outputs = self_attention_outputs[1:-1]
-            present_key_value = self_attention_outputs[-1]
-        else:
-            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
-
-        cross_attn_present_key_value = None
-        if self.is_decoder and encoder_hidden_states is not None:
-            if not hasattr(self, "crossattention"):
-                raise ValueError(
-                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
-                    " by setting `config.add_cross_attention=True`"
-                )
-
-            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
-            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
-            cross_attention_outputs = self.crossattention(
-                attention_output,
-                attention_mask,
-                head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                cross_attn_past_key_value,
-                output_attentions,
-            )
-            attention_output = cross_attention_outputs[0]
-            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
-
-            # add cross-attn cache to positions 3,4 of present_key_value tuple
-            cross_attn_present_key_value = cross_attention_outputs[-1]
-            present_key_value = present_key_value + cross_attn_present_key_value
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

         layer_output = apply_chunking_to_forward(
             self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
         )
         outputs = (layer_output,) + outputs

-        # if decoder, return the attn key/values as the last output
-        if self.is_decoder:
-            outputs = outputs + (present_key_value,)
-
         return outputs

     def feed_forward_chunk(self, attention_output):
@@ -898,14 +806,18 @@ class AlignTextLayer(GradientCheckpointingLayer):
         return layer_output

-# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->AlignText
 class AlignTextEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.layer = nn.ModuleList([AlignTextLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([AlignTextLayer(config) for i in range(config.num_hidden_layers)])
         self.gradient_checkpointing = False

+    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
+    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
+    @deprecate_kwarg("past_key_values", version="4.54.0")
+    @deprecate_kwarg("use_cache", version="4.54.0")
+    @can_return_tuple
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -918,65 +830,36 @@ class AlignTextEncoder(nn.Module):
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = True,
-    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+        **kwargs,
+    ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
-        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        next_decoder_cache = () if use_cache else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             layer_head_mask = head_mask[i] if head_mask is not None else None
-            past_key_value = past_key_values[i] if past_key_values is not None else None

             layer_outputs = layer_module(
-                hidden_states,
-                attention_mask,
-                layer_head_mask,
-                encoder_hidden_states,  # as a positional argument for gradient checkpointing
-                encoder_attention_mask=encoder_attention_mask,
-                past_key_value=past_key_value,
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                head_mask=layer_head_mask,
                 output_attentions=output_attentions,
+                **kwargs,
             )

             hidden_states = layer_outputs[0]
-            if use_cache:
-                next_decoder_cache += (layer_outputs[-1],)
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
-                if self.config.add_cross_attention:
-                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        if not return_dict:
-            return tuple(
-                v
-                for v in [
-                    hidden_states,
-                    next_decoder_cache,
-                    all_hidden_states,
-                    all_self_attentions,
-                    all_cross_attentions,
-                ]
-                if v is not None
-            )
-        return BaseModelOutputWithPastAndCrossAttentions(
+        return BaseModelOutput(
             last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
-            cross_attentions=all_cross_attentions,
         )
@@ -1052,6 +935,7 @@ class AlignTextModel(AlignPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1059,12 +943,13 @@ class AlignTextModel(AlignPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
+        **kwargs,
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         Examples:
@@ -1133,20 +1018,17 @@ class AlignTextModel(AlignPreTrainedModel):
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,
+            **kwargs,
         )
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        return BaseModelOutputWithPoolingAndCrossAttentions(
+        return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
-            cross_attentions=encoder_outputs.cross_attentions,
         )
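Usage stays backward compatible: with `@can_return_tuple`, callers can still request a plain tuple even though submodules are now always invoked with `return_dict=True` internally. A hedged sketch with a small randomly initialized config (no download); the output class name assumes the version in this diff:

import torch
from transformers import AlignTextConfig, AlignTextModel

config = AlignTextConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, intermediate_size=128)
model = AlignTextModel(config)
input_ids = torch.tensor([[101, 7592, 2088, 102]])

out = model(input_ids=input_ids)
print(type(out).__name__, out.last_hidden_state.shape)  # BaseModelOutputWithPooling torch.Size([1, 4, 64])

tup = model(input_ids=input_ids, return_dict=False)
print(isinstance(tup, tuple))  # True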
@@ -1180,6 +1062,7 @@ class AlignVisionModel(AlignPreTrainedModel):
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.convolution

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1219,7 +1102,7 @@ class AlignVisionModel(AlignPreTrainedModel):
         encoder_outputs = self.encoder(
             embedding_output,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,
         )
         # Apply pooling
         last_hidden_state = encoder_outputs[0]
@@ -1227,9 +1110,6 @@ class AlignVisionModel(AlignPreTrainedModel):
         # Reshape (batch_size, projection_dim, 1 , 1) -> (batch_size, projection_dim)
         pooled_output = pooled_output.reshape(pooled_output.shape[:2])

-        if not return_dict:
-            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
-
         return BaseModelOutputWithPoolingAndNoAttention(
             last_hidden_state=last_hidden_state,
             pooler_output=pooled_output,
@@ -1369,6 +1249,7 @@ class AlignModel(AlignPreTrainedModel):
         return image_features

+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -1419,7 +1300,7 @@ class AlignModel(AlignPreTrainedModel):
         vision_outputs = self.vision_model(
             pixel_values=pixel_values,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,
         )

         text_outputs = self.text_model(
@@ -1431,7 +1312,7 @@ class AlignModel(AlignPreTrainedModel):
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True,
         )

         image_embeds = vision_outputs[1]
@@ -1450,10 +1331,6 @@ class AlignModel(AlignPreTrainedModel):
         if return_loss:
             loss = align_loss(logits_per_text)

-        if not return_dict:
-            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
-            return ((loss,) + output) if loss is not None else output
-
         return AlignOutput(
             loss=loss,
             logits_per_image=logits_per_image,

View File

@@ -26,14 +26,14 @@ from ...activations import ACT2FN
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
-    BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPooling,
     BaseModelOutputWithPoolingAndCrossAttentions,
     BaseModelOutputWithPoolingAndProjection,
 )
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import ModelOutput, auto_docstring, logging, torch_int
+from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
+from ...utils.deprecation import deprecate_kwarg
 from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig
@@ -180,7 +180,6 @@ class AltRobertaEmbeddings(nn.Module):
         return position_ids.unsqueeze(0).expand(input_shape)

-# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->AltRoberta
 class AltRobertaSelfAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
@@ -206,13 +205,9 @@ class AltRobertaSelfAttention(nn.Module):
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

-        self.is_decoder = config.is_decoder
-
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
+    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
+    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
+    @deprecate_kwarg("past_key_value", version="4.54.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -223,54 +218,18 @@ class AltRobertaSelfAttention(nn.Module):
         past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
     ) -> tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)
-
-        # If this is instantiated as a cross-attention module, the keys
-        # and values come from an encoder; the attention mask needs to be
-        # such that the encoder's padding tokens are not attended to.
-        is_cross_attention = encoder_hidden_states is not None
-
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_layer = past_key_value[0]
-            value_layer = past_key_value[1]
-            attention_mask = encoder_attention_mask
-        elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
-            attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
-        else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        use_cache = past_key_value is not None
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_layer, value_layer)
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.attention_head_size)
+
+        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             query_length, key_length = query_layer.shape[2], key_layer.shape[2]
-            if use_cache:
-                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-                    -1, 1
-                )
-            else:
-                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
             position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
             distance = position_ids_l - position_ids_r
@@ -310,8 +269,6 @@ class AltRobertaSelfAttention(nn.Module):
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

-        if self.is_decoder:
-            outputs = outputs + (past_key_value,)
-
         return outputs
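Unlike the ALIGN text tower, AltRobertaSelfAttention keeps its relative-position path; the small sketch below (toy sizes, plain tensors) illustrates the distance-to-index shift it relies on:

import torch

max_position_embeddings = 4
query_length = key_length = 3

position_ids_l = torch.arange(query_length, dtype=torch.long).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long).view(1, -1)
distance = position_ids_l - position_ids_r        # values in [-(key_length - 1), query_length - 1]
index = distance + max_position_embeddings - 1    # shifted into [0, 2 * max_position_embeddings - 2]
print(distance.tolist())  # [[0, -1, -2], [1, 0, -1], [2, 1, 0]]
print(index.tolist())     # [[3, 2, 1], [4, 3, 2], [5, 4, 3]]

# These indices select rows of the (2 * max_position_embeddings - 1)-row distance_embedding table.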
@@ -335,7 +292,6 @@ ALT_ROBERTA_SELF_ATTENTION_CLASSES = {
 }

-# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA
 class AltRobertaAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
@@ -363,6 +319,9 @@ class AltRobertaAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)

+    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
+    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
+    @deprecate_kwarg("past_key_value", version="4.54.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -375,12 +334,9 @@ class AltRobertaAttention(nn.Module):
     ) -> tuple[torch.Tensor]:
         self_outputs = self.self(
             hidden_states,
-            attention_mask,
-            head_mask,
-            encoder_hidden_states,
-            encoder_attention_mask,
-            past_key_value,
-            output_attentions,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -418,22 +374,19 @@ class AltRobertaOutput(nn.Module):
         return hidden_states

-# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta
+# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->AltRoberta
 class AltRobertaLayer(GradientCheckpointingLayer):
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
         self.attention = AltRobertaAttention(config)
-        self.is_decoder = config.is_decoder
-        self.add_cross_attention = config.add_cross_attention
-        if self.add_cross_attention:
-            if not self.is_decoder:
-                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = AltRobertaAttention(config, position_embedding_type="absolute")
         self.intermediate = AltRobertaIntermediate(config)
         self.output = AltRobertaOutput(config)

+    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
+    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
+    @deprecate_kwarg("past_key_value", version="4.54.0")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -443,60 +396,23 @@ class AltRobertaLayer(GradientCheckpointingLayer):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> tuple[torch.Tensor]:
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
             hidden_states,
-            attention_mask,
-            head_mask,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
             output_attentions=output_attentions,
-            past_key_value=self_attn_past_key_value,
+            **kwargs,
         )
         attention_output = self_attention_outputs[0]
-
-        # if decoder, the last output is tuple of self-attn cache
-        if self.is_decoder:
-            outputs = self_attention_outputs[1:-1]
-            present_key_value = self_attention_outputs[-1]
-        else:
-            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
-
-        cross_attn_present_key_value = None
-        if self.is_decoder and encoder_hidden_states is not None:
-            if not hasattr(self, "crossattention"):
-                raise ValueError(
-                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
-                    " by setting `config.add_cross_attention=True`"
-                )
-
-            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
-            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
-            cross_attention_outputs = self.crossattention(
-                attention_output,
-                attention_mask,
-                head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                cross_attn_past_key_value,
-                output_attentions,
-            )
-            attention_output = cross_attention_outputs[0]
-            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
-
-            # add cross-attn cache to positions 3,4 of present_key_value tuple
-            cross_attn_present_key_value = cross_attention_outputs[-1]
-            present_key_value = present_key_value + cross_attn_present_key_value
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

         layer_output = apply_chunking_to_forward(
             self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
         )
         outputs = (layer_output,) + outputs

-        # if decoder, return the attn key/values as the last output
-        if self.is_decoder:
-            outputs = outputs + (present_key_value,)
-
         return outputs

     def feed_forward_chunk(self, attention_output):
@@ -505,14 +421,19 @@ class AltRobertaLayer(GradientCheckpointingLayer):
         return layer_output

-# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->AltRoberta
+# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->AltRoberta
 class AltRobertaEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.layer = nn.ModuleList([AltRobertaLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([AltRobertaLayer(config) for i in range(config.num_hidden_layers)])
         self.gradient_checkpointing = False

+    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
+    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
+    @deprecate_kwarg("past_key_values", version="4.54.0")
+    @deprecate_kwarg("use_cache", version="4.54.0")
+    @can_return_tuple
     def forward(
         self,
         hidden_states: torch.Tensor,
@ -525,65 +446,36 @@ class AltRobertaEncoder(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True, return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: **kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states=hidden_states,
attention_mask, attention_mask=attention_mask,
layer_head_mask, head_mask=layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
**kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],) all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict: return BaseModelOutput(
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
) )
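The `@deprecate_kwarg(...)` decorators added to this forward only mark arguments scheduled for removal; conceptually such a decorator warns when the named keyword is still supplied and otherwise forwards the call untouched. A rough, hypothetical re-implementation for illustration; the real `transformers.utils.deprecation.deprecate_kwarg` does more (e.g. renaming a kwarg, version-gated behaviour):

import functools
import warnings


def deprecate_kwarg_sketch(name, version):
    """Warn if `name` is still passed as a keyword argument (illustrative only)."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if kwargs.get(name) is not None:
                warnings.warn(
                    f"`{name}` is deprecated and will be removed in version {version}.",
                    FutureWarning,
                )
            return func(*args, **kwargs)
        return wrapper
    return decorator


@deprecate_kwarg_sketch("use_cache", version="4.54.0")
def forward(hidden_states, use_cache=None):
    return hidden_states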
@ -787,6 +679,7 @@ class AltCLIPEncoder(nn.Module):
self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@can_return_tuple
def forward( def forward(
self, self,
inputs_embeds, inputs_embeds,
@ -853,8 +746,6 @@ class AltCLIPEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
) )
@ -1008,6 +899,7 @@ class AltCLIPVisionTransformer(nn.Module):
self.encoder = AltCLIPEncoder(config) self.encoder = AltCLIPEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1033,16 +925,13 @@ class AltCLIPVisionTransformer(nn.Module):
inputs_embeds=hidden_states, inputs_embeds=hidden_states,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
last_hidden_state = encoder_outputs[0] last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :] pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output) pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling( return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state, last_hidden_state=last_hidden_state,
pooler_output=pooled_output, pooler_output=pooled_output,
@ -1106,16 +995,11 @@ class AltCLIPVisionModel(AltCLIPPreTrainedModel):
@auto_docstring( @auto_docstring(
custom_intro=""" custom_intro="""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of The model behaves as an encoder following the architecture described in *Attention is
cross-attention is added between the self-attention layers, following the architecture described in *Attention is
all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
Kaiser and Illia Polosukhin. Kaiser and Illia Polosukhin.
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
.. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
""" """
) )
class AltRobertaModel(AltCLIPPreTrainedModel): class AltRobertaModel(AltCLIPPreTrainedModel):
@ -1152,6 +1036,10 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@auto_docstring @auto_docstring
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward( def forward(
@ -1176,11 +1064,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
@ -1194,11 +1077,8 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) attention_mask = torch.ones(((batch_size, seq_length)), device=device)
if token_type_ids is None: if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"): if hasattr(self.embeddings, "token_type_ids"):
@ -1212,21 +1092,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads. # ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
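`get_extended_attention_mask` above turns the 2D padding mask into a tensor that broadcasts over every head and query position and can be added directly to the raw attention scores. A minimal sketch of the non-causal case, assuming the usual large-negative fill value (the library's dtype handling is more careful than this):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])            # (batch, seq): 0 marks padding
extended = attention_mask[:, None, None, :]              # (batch, 1, 1, seq), broadcastable to all heads
extended = (1.0 - extended.float()) * torch.finfo(torch.float32).min

# `extended` is added to the raw attention scores, pushing padded keys to ~-inf before softmax.
print(extended.shape)  # torch.Size([1, 1, 1, 4])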
@ -1235,33 +1100,23 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
) )
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict: return BaseModelOutputWithPooling(
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
pooler_output=pooled_output, pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states, hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions, attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
) )
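The removed `if not return_dict:` branches are replaced by calling sub-modules with `return_dict=True` and letting the `@can_return_tuple` decorator hand back a plain tuple when the caller asked for one. A rough, hypothetical sketch of that conversion idea only; the actual decorator in `transformers` is more involved:

import functools


def can_return_tuple_sketch(forward):
    """Convert a ModelOutput-like result to a tuple when return_dict=False (illustrative only)."""
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=True, **kwargs):
        output = forward(self, *args, **kwargs)
        # ModelOutput subclasses expose `.to_tuple()`, returning the non-None fields in order.
        return output if return_dict else output.to_tuple()
    return wrapper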
@ -1284,6 +1139,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding: def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
return super().resize_token_embeddings(new_num_tokens) return super().resize_token_embeddings(new_num_tokens)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1326,11 +1184,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
# last module outputs # last module outputs
@ -1343,9 +1199,6 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
projection_state = self.transformation(sequence_output) projection_state = self.transformation(sequence_output)
pooler_output = projection_state[:, 0] pooler_output = projection_state[:, 0]
if not return_dict:
return (projection_state, pooler_output) + outputs[2:4]
return BaseModelOutputWithPoolingAndProjection( return BaseModelOutputWithPoolingAndProjection(
last_hidden_state=projection_state, last_hidden_state=projection_state,
pooler_output=pooler_output, pooler_output=pooler_output,

View File

@ -225,7 +225,7 @@ class ArceeAttention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)

View File

@ -532,7 +532,7 @@ class AriaTextAttention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
@ -1113,11 +1113,12 @@ class AriaModel(AriaPreTrainedModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()( special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
) )
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] special_image_mask = special_image_mask.all(-1)
else: else:
image_embeds = input_ids == self.config.image_token_id special_image_mask = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = self.get_image_features( image_features = self.get_image_features(
pixel_values=pixel_values, pixel_values=pixel_values,
pixel_mask=pixel_mask, pixel_mask=pixel_mask,
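The rewritten branch above locates image placeholder positions either from `input_ids` or, when only embeddings are given, by comparing each embedding row against the embedding of the image token, then scatters the projected image features into those slots. A toy sketch of the mask-and-scatter step; the token id, sizes and tensors here are illustrative, not Aria's real configuration:

import torch

image_token_id = 9
input_ids = torch.tensor([[1, 9, 9, 2]])                  # two image placeholder tokens
inputs_embeds = torch.zeros(1, 4, 3)                      # (batch, seq, hidden)
image_features = torch.ones(2, 3)                         # one feature row per placeholder

special_image_mask = input_ids == image_token_id          # (batch, seq) boolean mask
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)

# masked_scatter fills the masked positions, in order, from the flattened features
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1], inputs_embeds[0, 2])           # both rows are now the image features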

View File

@ -1446,11 +1446,12 @@ class AriaModel(LlavaModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()( special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
) )
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] special_image_mask = special_image_mask.all(-1)
else: else:
image_embeds = input_ids == self.config.image_token_id special_image_mask = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = self.get_image_features( image_features = self.get_image_features(
pixel_values=pixel_values, pixel_values=pixel_values,
pixel_mask=pixel_mask, pixel_mask=pixel_mask,

View File

@ -302,14 +302,14 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()( special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
) )
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] special_image_mask = special_image_mask.all(-1)
else: else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@ -223,14 +223,14 @@ class AyaVisionModel(LlavaModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()( special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
) )
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] special_image_mask = special_image_mask.all(-1)
else: else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@ -296,7 +296,7 @@ class BambaAttention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)

View File

@ -1855,6 +1855,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support (e.g. T5) _supports_quantized_cache = False # not all LM backbones support (e.g. T5)
_keep_in_fp32_modules = ["query_tokens", "qformer"] _keep_in_fp32_modules = ["query_tokens", "qformer"]
def __init__(self, config: Blip2Config): def __init__(self, config: Blip2Config):
@ -1971,10 +1972,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
def forward( def forward(
self, self,
pixel_values: torch.FloatTensor, pixel_values: torch.FloatTensor,
input_ids: torch.FloatTensor, input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
@ -2066,14 +2068,25 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
language_model_attention_mask = torch.ones( language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
) )
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
# if the model already has "image_token_id" then the input is expanded to account for image embeds # if the model already has "image_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concating # otherwise we expand manually by concatenating
if getattr(self.config, "image_token_id", None) is not None: if getattr(self.config, "image_token_id", None) is not None:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else: else:
@ -2146,6 +2159,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
pixel_values: torch.FloatTensor, pixel_values: torch.FloatTensor,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
interpolate_pos_encoding: bool = False, interpolate_pos_encoding: bool = False,
**generate_kwargs, **generate_kwargs,
) -> torch.LongTensor: ) -> torch.LongTensor:
@ -2159,6 +2173,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
The sequence used as a prompt for the generation. The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices Mask to avoid performing attention on padding token indices
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the positional encoding of the image embeddings.
Returns: Returns:
captions (list): A list of strings of length batch_size * num_captions. captions (list): A list of strings of length batch_size * num_captions.
@ -2193,22 +2211,32 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
) )
if inputs_embeds is None:
if input_ids is None: if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id] start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_id", None) is not None: if getattr(self.config, "image_token_id", None) is not None:
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1) input_ids = input_ids.repeat(batch_size, 1)
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
# if the model already has "image_token_id" then the input is expanded to account for image embeds # if the model already has "image_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating # otherwise we expand manually by concatenating
if getattr(self.config, "image_token_id", None) is not None: if getattr(self.config, "image_token_id", None) is not None:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) if input_ids is None:
inputs_embeds[special_image_mask] = language_model_inputs.flatten() special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else: else:
logger.warning_once( logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. " "Expanding inputs for image tokens in BLIP-2 should be done in processing. "

View File

@ -1026,7 +1026,6 @@ class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
@auto_docstring @auto_docstring
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None,

View File

@ -26,13 +26,14 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_bros import BrosConfig from .configuration_bros import BrosConfig
@ -150,7 +151,6 @@ class BrosTextEmbeddings(nn.Module):
token_type_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
if input_ids is not None: if input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
@ -160,7 +160,7 @@ class BrosTextEmbeddings(nn.Module):
seq_length = input_shape[1] seq_length = input_shape[1]
if position_ids is None: if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None: if token_type_ids is None:
if hasattr(self, "token_type_ids"): if hasattr(self, "token_type_ids"):
@ -208,14 +208,7 @@ class BrosSelfAttention(nn.Module):
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor): @deprecate_kwarg("past_key_value", version="4.54.0")
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -227,42 +220,21 @@ class BrosSelfAttention(nn.Module):
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[torch.Tensor] = False, output_attentions: Optional[torch.Tensor] = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states) hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be # and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to. # such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None: if is_cross_attention:
# reuse k,v, cross_attentions key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = past_key_value[0] value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else: else:
key_layer = self.transpose_for_scores(self.key(hidden_states)) key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.transpose_for_scores(self.value(hidden_states)) value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
query_layer = self.transpose_for_scores(mixed_query_layer)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores. # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
@ -317,7 +289,7 @@ class BrosSelfAttention(nn.Module):
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder: if self.is_decoder:
outputs = outputs + (past_key_value,) outputs = outputs + (None,)
return outputs return outputs
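The refactored attention above drops the `transpose_for_scores` helper in favour of an inline `view(hidden_shape).transpose(1, 2)`; both produce the same `(batch, num_heads, seq_len, head_dim)` layout. A quick equivalence check with toy sizes (all dimensions here are arbitrary):

import torch

batch, seq, num_heads, head_size = 2, 5, 4, 8
hidden_states = torch.randn(batch, seq, num_heads * head_size)

# Old helper: explicit shape tuple + permute
old = hidden_states.view(batch, seq, num_heads, head_size).permute(0, 2, 1, 3)

# New inline form: view to (batch, -1, num_heads, head_size) then swap the seq and head dims
hidden_shape = (hidden_states.shape[0], -1, num_heads, head_size)
new = hidden_states.view(hidden_shape).transpose(1, 2)

assert torch.equal(old, new)   # both are (batch, num_heads, seq, head_size)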
@ -364,6 +336,7 @@ class BrosAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -382,7 +355,6 @@ class BrosAttention(nn.Module):
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
@ -435,6 +407,7 @@ class BrosLayer(GradientCheckpointingLayer):
self.intermediate = BrosIntermediate(config) self.intermediate = BrosIntermediate(config)
self.output = BrosOutput(config) self.output = BrosOutput(config)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -446,50 +419,38 @@ class BrosLayer(GradientCheckpointingLayer):
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
hidden_states, hidden_states,
bbox_pos_emb=bbox_pos_emb, bbox_pos_emb=bbox_pos_emb,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
) )
attention_output = self_attention_outputs[0] attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache # if decoder, the last output is tuple of self-attn cache
if self.is_decoder: if self.is_decoder:
outputs = self_attention_outputs[1:-1] outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else: else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
if hasattr(self, "crossattention"): if hasattr(self, "crossattention"):
raise Exception( raise Exception(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
) )
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention( cross_attention_outputs = self.crossattention(
attention_output, attention_output,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
cross_attn_past_key_value, output_attentions=output_attentions,
output_attentions,
) )
attention_output = cross_attention_outputs[0] attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward( layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.feed_forward_chunk,
self.chunk_size_feed_forward, self.chunk_size_feed_forward,
@ -500,7 +461,7 @@ class BrosLayer(GradientCheckpointingLayer):
# if decoder, return the attn key/values as the last output # if decoder, return the attn key/values as the last output
if self.is_decoder: if self.is_decoder:
outputs = outputs + (present_key_value,) outputs = outputs + (None,)
return outputs return outputs
@ -516,6 +477,9 @@ class BrosEncoder(nn.Module):
self.config = config self.config = config
self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)]) self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)])
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -529,33 +493,28 @@ class BrosEncoder(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True, return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states=hidden_states,
bbox_pos_emb, bbox_pos_emb=bbox_pos_emb,
attention_mask, attention_mask=attention_mask,
layer_head_mask, head_mask=layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],) all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention: if self.config.add_cross_attention:
@ -564,21 +523,8 @@ class BrosEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict: return BaseModelOutputWithCrossAttentions(
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
@ -689,6 +635,9 @@ class BrosModel(BrosPreTrainedModel):
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -736,11 +685,6 @@ class BrosModel(BrosPreTrainedModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
@ -756,9 +700,6 @@ class BrosModel(BrosPreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device) attention_mask = torch.ones(input_shape, device=device)
@ -797,7 +738,6 @@ class BrosModel(BrosPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
) )
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token # if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
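The comment above refers to expanding a 2-point box `(x0, y0, x1, y1)` into its four corners, i.e. 8 floats per token. One common way to do this with index selection is sketched below; the corner order `(x0,y0), (x1,y0), (x1,y1), (x0,y1)` is an assumption for illustration and may not match the exact ordering used in the file:

import torch

# (batch, seq, 4): each token's box as (x0, y0, x1, y1) = top-left / bottom-right
bbox = torch.tensor([[[0.1, 0.2, 0.5, 0.6]]])

# Expand to 4 corner points -> 8 floats per token
bbox_4pt = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
print(bbox_4pt.shape)   # torch.Size([1, 1, 8])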
@ -813,22 +753,16 @@ class BrosModel(BrosPreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions( return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
pooler_output=pooled_output, pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states, hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions, attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions, cross_attentions=encoder_outputs.cross_attentions,
@ -852,6 +786,7 @@ class BrosForTokenClassification(BrosPreTrainedModel):
self.init_weights() self.init_weights()
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -908,7 +843,7 @@ class BrosForTokenClassification(BrosPreTrainedModel):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@ -927,10 +862,6 @@ class BrosForTokenClassification(BrosPreTrainedModel):
else: else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, loss=loss,
logits=logits, logits=logits,
@ -976,6 +907,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
self.init_weights() self.init_weights()
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1037,7 +969,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
last_hidden_states = outputs[0] last_hidden_states = outputs[0]
@ -1082,10 +1014,6 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
loss = initial_token_loss + subsequent_token_loss loss = initial_token_loss + subsequent_token_loss
if not return_dict:
output = (initial_token_logits, subsequent_token_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return BrosSpadeOutput( return BrosSpadeOutput(
loss=loss, loss=loss,
initial_token_logits=initial_token_logits, initial_token_logits=initial_token_logits,
@ -1118,6 +1046,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
self.init_weights() self.init_weights()
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1173,7 +1102,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
last_hidden_states = outputs[0] last_hidden_states = outputs[0]
@ -1203,10 +1132,6 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask]) loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask])
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, loss=loss,
logits=logits, logits=logits,

View File

@ -963,25 +963,28 @@ class ChameleonModel(ChameleonPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None: if inputs_embeds is None:
raise ValueError( inputs_embeds = self.embed_tokens(input_ids)
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None: if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values) if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum() n_image_tokens_in_text = (special_image_mask).sum()
n_image_features = image_tokens.shape[0] * image_tokens.shape[1] special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_embeds = self.get_image_features(pixel_values)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_embeds.numel():
n_image_features = image_embeds.shape[0] * image_embeds.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
) )
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# torch.jit.trace() doesn't support cache objects in the output # torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing(): if use_cache and past_key_values is None and not torch.jit.is_tracing():
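The `(input_ids is None) ^ (inputs_embeds is not None)` guard above raises precisely when both inputs are given or both are missing, i.e. it enforces that exactly one of the two is provided. A tiny truth-table check of that XOR (function name is illustrative):

def should_raise(input_ids, inputs_embeds):
    # True (i.e. "error") when both are provided or both are None
    return (input_ids is None) ^ (inputs_embeds is not None)

assert should_raise(None, None) is True            # neither given -> error
assert should_raise("ids", "embeds") is True       # both given -> error
assert should_raise("ids", None) is False          # only input_ids -> ok
assert should_raise(None, "embeds") is False       # only inputs_embeds -> ok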

View File

@ -14,9 +14,8 @@
# limitations under the License. # limitations under the License.
"""PyTorch Chinese-CLIP model.""" """PyTorch Chinese-CLIP model."""
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Union from typing import Any, Callable, Optional, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@ -26,13 +25,13 @@ from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling, BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions,
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig
@ -90,7 +89,7 @@ class ChineseCLIPOutput(ModelOutput):
) )
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->ChineseCLIPText # Copied from transformers.models.align.modeling_align.AlignTextEmbeddings with Align->ChineseCLIP
class ChineseCLIPTextEmbeddings(nn.Module): class ChineseCLIPTextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.""" """Construct the embeddings from word, position and token_type embeddings."""
@ -119,7 +118,6 @@ class ChineseCLIPTextEmbeddings(nn.Module):
token_type_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
if input_ids is not None: if input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
@ -129,7 +127,7 @@ class ChineseCLIPTextEmbeddings(nn.Module):
seq_length = input_shape[1] seq_length = input_shape[1]
if position_ids is None: if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] position_ids = self.position_ids[:, :seq_length]
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@ -239,9 +237,37 @@ class ChineseCLIPVisionEmbeddings(nn.Module):
return embeddings return embeddings
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ChineseCLIPText # Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
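`eager_attention_forward` above is plain scaled dot-product attention with an optional additive mask, dropout and per-head mask; attention modules pick it (or a fused kernel) via the `attention_interface` dispatch visible at the end of this diff. A toy call showing the expected tensor layout, assuming the function defined above is in scope; every size here is made up:

import torch

batch, heads, q_len, k_len, head_dim = 1, 2, 3, 3, 4
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, k_len, head_dim)
value = torch.randn(batch, heads, k_len, head_dim)

# The module argument is only consulted for `.training` (dropout), so a bare nn.Module suffices here.
module = torch.nn.Module()

attn_output, attn_weights = eager_attention_forward(
    module, query, key, value, attention_mask=None, scaling=head_dim**-0.5
)
print(attn_output.shape)   # (batch, q_len, heads, head_dim) after the final transpose
print(attn_weights.shape)  # (batch, heads, q_len, k_len)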
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->ChineseCLIP
class ChineseCLIPTextSelfAttention(nn.Module): class ChineseCLIPTextSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None): def __init__(self, config):
super().__init__() super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError( raise ValueError(
@ -249,6 +275,7 @@ class ChineseCLIPTextSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})" f"heads ({config.num_attention_heads})"
) )
self.config = config
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -258,20 +285,12 @@ class ChineseCLIPTextSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr( self.attention_dropout = config.attention_probs_dropout_prob
config, "position_embedding_type", "absolute" self.scaling = self.attention_head_size**-0.5
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -281,96 +300,33 @@ class ChineseCLIPTextSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states) input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
# and values come from an encoder; the attention mask needs to be key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
# such that the encoder's padding tokens are not attended to. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None: attention_interface: Callable = eager_attention_forward
# reuse k,v, cross_attentions if self.config._attn_implementation != "eager":
key_layer = past_key_value[0] attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer) attn_output, attn_weights = attention_interface(
self,
use_cache = past_key_value is not None query_states,
if self.is_decoder: key_states,
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. value_states,
# Further calls to cross_attention layer can then reuse all cross-attention attention_mask,
# key/value_states (first "if" case) dropout=0.0 if not self.training else self.attention_dropout,
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of scaling=self.scaling,
# all previous decoder key/value_states. Further calls to uni-directional self-attention head_mask=head_mask,
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) **kwargs,
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
) )
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) attn_output = attn_output.reshape(*input_shape, -1).contiguous()
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in ChineseCLIPTextModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs return outputs
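
The removed transpose_for_scores helper and the new inline view(...).transpose(1, 2) produce the same (batch, heads, seq, head_dim) layout; a quick equivalence check with made-up sizes:

import torch

batch, seq, num_heads, head_size = 2, 5, 8, 64          # illustrative sizes only
hidden = torch.randn(batch, seq, num_heads * head_size)

# Old helper: view to (batch, seq, heads, head_size), then permute heads forward.
old_layout = hidden.view(batch, seq, num_heads, head_size).permute(0, 2, 1, 3)

# New inline form used above: view against (*input_shape, -1, head_size), then transpose(1, 2).
input_shape = hidden.shape[:-1]
new_layout = hidden.view(*input_shape, -1, head_size).transpose(1, 2)

assert torch.equal(old_layout, new_layout)   # identical (batch, heads, seq, head_dim) tensors
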
@ -389,18 +345,11 @@ class ChineseCLIPTextSelfOutput(nn.Module):
return hidden_states return hidden_states
CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = { # Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->ChineseCLIP
"eager": ChineseCLIPTextSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT
class ChineseCLIPTextAttention(nn.Module): class ChineseCLIPTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None): def __init__(self, config):
super().__init__() super().__init__()
self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( self.self = ChineseCLIPTextSelfAttention(config)
config, position_embedding_type=position_embedding_type
)
self.output = ChineseCLIPTextSelfOutput(config) self.output = ChineseCLIPTextSelfOutput(config)
self.pruned_heads = set() self.pruned_heads = set()
@ -422,6 +371,9 @@ class ChineseCLIPTextAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -431,15 +383,14 @@ class ChineseCLIPTextAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
encoder_hidden_states, output_attentions=output_attentions,
encoder_attention_mask, **kwargs,
past_key_value,
output_attentions,
) )
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -468,66 +419,37 @@ class ChineseCLIPVisionAttention(nn.Module):
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward( def forward(
self, self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, **kwargs
hidden_states: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size() input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
# get query proj query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.scale
query_states = self.q_proj(hidden_states) * self.scale key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim) attention_interface: Callable = eager_attention_forward
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) if self.config._attn_implementation != "eager":
key_states = key_states.view(*proj_shape) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1) attn_output, attn_weights = attention_interface(
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) self,
query_states,
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): key_states,
raise ValueError( value_states,
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" None,
f" {attn_weights.size()}" dropout=0.0 if not self.training else self.dropout,
scaling=1.0,
**kwargs,
) )
attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_output = attn_output.reshape(*input_shape, -1).contiguous()
if output_attentions:
# this operation is a bit akward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped return attn_output, attn_weights
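
The rewritten vision attention folds the 1/sqrt(head_dim) factor into the query projection and passes scaling=1.0 to the attention backend; the two placements are numerically interchangeable, as this small check with made-up sizes illustrates:

import torch

q = torch.randn(2, 8, 10, 64)
k = torch.randn(2, 8, 10, 64)
scale = 64 ** -0.5

pre_scaled = torch.matmul(q * scale, k.transpose(-1, -2))   # scale the query first, as above
post_scaled = torch.matmul(q, k.transpose(-1, -2)) * scale  # scale the logits afterwards

print(torch.allclose(pre_scaled, post_scaled, atol=1e-5))   # True up to float rounding
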
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText
@ -577,22 +499,19 @@ class ChineseCLIPVisionMLP(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText # Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->ChineseCLIP
class ChineseCLIPTextLayer(GradientCheckpointingLayer): class ChineseCLIPTextLayer(GradientCheckpointingLayer):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1 self.seq_len_dim = 1
self.attention = ChineseCLIPTextAttention(config) self.attention = ChineseCLIPTextAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = ChineseCLIPTextAttention(config, position_embedding_type="absolute")
self.intermediate = ChineseCLIPTextIntermediate(config) self.intermediate = ChineseCLIPTextIntermediate(config)
self.output = ChineseCLIPTextOutput(config) self.output = ChineseCLIPTextOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -602,60 +521,23 @@ class ChineseCLIPTextLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
past_key_value=self_attn_past_key_value, **kwargs,
) )
attention_output = self_attention_outputs[0] attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward( layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
) )
outputs = (layer_output,) + outputs outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs return outputs
def feed_forward_chunk(self, attention_output): def feed_forward_chunk(self, attention_output):
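
apply_chunking_to_forward, kept from the previous version, runs the intermediate/output projections on slices of the sequence when config.chunk_size_feed_forward > 0. A simplified single-tensor sketch of that behaviour (the real helper accepts several input tensors; chunked_feed_forward here is a hypothetical stand-in):

import torch

def chunked_feed_forward(forward_fn, chunk_size, chunk_dim, tensor):
    # chunk_size == 0 means "no chunking"; otherwise apply forward_fn slice by slice
    # along chunk_dim and concatenate, trading scheduling overhead for lower peak memory.
    if chunk_size == 0:
        return forward_fn(tensor)
    chunks = tensor.split(chunk_size, dim=chunk_dim)
    return torch.cat([forward_fn(chunk) for chunk in chunks], dim=chunk_dim)

x = torch.randn(2, 12, 16)
out = chunked_feed_forward(torch.tanh, chunk_size=4, chunk_dim=1, tensor=x)
assert torch.allclose(out, torch.tanh(x))   # elementwise op: chunked result matches the direct one
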
@ -777,14 +659,19 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ChineseCLIPText # Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP
class ChineseCLIPTextEncoder(nn.Module): class ChineseCLIPTextEncoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for _ in range(config.num_hidden_layers)]) self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -797,65 +684,36 @@ class ChineseCLIPTextEncoder(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True, return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: **kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states=hidden_states,
attention_mask, attention_mask=attention_mask,
layer_head_mask, head_mask=layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
**kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],) all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict: return BaseModelOutput(
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
) )
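
With @can_return_tuple on forward, the encoder can always build a BaseModelOutput and leave the tuple conversion to the decorator, which is why the `if not return_dict:` branch above is dropped. A rough sketch of the contract such a decorator provides — this is not the library's implementation; return_tuple_when_asked is a hypothetical name and it assumes the module exposes config.use_return_dict, as these models do:

import functools

def return_tuple_when_asked(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        # Decide from the explicit kwarg first, then fall back to the config default.
        use_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output = forward(self, *args, **kwargs)
        # ModelOutput instances expose to_tuple(), which returns the non-None fields positionally.
        return output if use_dict else output.to_tuple()

    return wrapper
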
@ -874,6 +732,7 @@ class ChineseCLIPVisionEncoder(nn.Module):
self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@can_return_tuple
def forward( def forward(
self, self,
inputs_embeds, inputs_embeds,
@ -922,8 +781,6 @@ class ChineseCLIPVisionEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
) )
@ -940,6 +797,7 @@ class ChineseCLIPVisionTransformer(nn.Module):
self.encoder = ChineseCLIPVisionEncoder(config) self.encoder = ChineseCLIPVisionEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -965,16 +823,13 @@ class ChineseCLIPVisionTransformer(nn.Module):
inputs_embeds=hidden_states, inputs_embeds=hidden_states,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
last_hidden_state = encoder_outputs[0] last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :] pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output) pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling( return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state, last_hidden_state=last_hidden_state,
pooler_output=pooled_output, pooler_output=pooled_output,
@ -1034,6 +889,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1050,18 +906,13 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
@ -1093,56 +944,28 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads. # ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings( embedding_output = self.embeddings(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
) )
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict: return BaseModelOutputWithPooling(
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
pooler_output=pooled_output, pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states, hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions, attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
) )
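
get_extended_attention_mask, still used above, turns the [batch, seq] padding mask into the additive [batch, 1, 1, seq] tensor that eager_attention_forward adds to the logits. A hedged sketch of that transformation (make_extended_attention_mask is a hypothetical stand-in for the mixin method):

import torch

def make_extended_attention_mask(attention_mask, dtype=torch.float32):
    # 1 = attend, 0 = padding. Kept positions add (-)0.0 to the logits; padded positions
    # add the most negative representable value, which softmax turns into ~0 probability.
    extended = attention_mask[:, None, None, :].to(dtype)
    return (1.0 - extended) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0, 0]])
print(make_extended_attention_mask(mask).shape)   # torch.Size([1, 1, 1, 5])
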
@ -1343,6 +1166,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
return image_features return image_features
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1392,7 +1216,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding, interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict, return_dict=True,
) )
text_outputs = self.text_model( text_outputs = self.text_model(
@ -1402,7 +1226,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
image_embeds = vision_outputs[1] image_embeds = vision_outputs[1]
@ -1424,14 +1248,6 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
if return_loss: if return_loss:
loss = chinese_clip_loss(logits_per_text) loss = chinese_clip_loss(logits_per_text)
if not return_dict:
# fix the None pooled_output of text_outputs to conform with dict_output
pooled_output = text_outputs[1]
if pooled_output is None:
text_outputs = (text_outputs[0],) + text_outputs[2:]
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
return ((loss,) + output) if loss is not None else output
return ChineseCLIPOutput( return ChineseCLIPOutput(
loss=loss, loss=loss,
logits_per_image=logits_per_image, logits_per_image=logits_per_image,

View File

@ -17,7 +17,7 @@
import collections import collections
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Union from typing import Any, Callable, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -26,13 +26,14 @@ from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutput,
BaseModelOutputWithPooling, BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions,
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
@ -399,11 +400,6 @@ class ClapAudioSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -412,11 +408,11 @@ class ClapAudioSelfAttention(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states) hidden_shape = (batch_size, dim, -1, self.attention_head_size)
key_layer = self.transpose_for_scores(self.key(hidden_states)) query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
query_layer = self.transpose_for_scores(mixed_query_layer) value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Take the dot product between "query" and "key" to get the raw attention scores. # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
@ -1090,9 +1086,37 @@ class ClapTextEmbeddings(nn.Module):
return position_ids.unsqueeze(0).expand(input_shape) return position_ids.unsqueeze(0).expand(input_shape)
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ClapText # Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
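
The optional head_mask in the eager path is a per-head scale reshaped to broadcast over batch and both sequence dimensions, which lets callers silence individual heads at runtime. A small check with made-up sizes:

import torch

num_heads, seq = 4, 6
attn_weights = torch.rand(2, num_heads, seq, seq)
head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])       # keep every head except head 1

masked = attn_weights * head_mask.view(1, -1, 1, 1)  # same broadcast as in eager_attention_forward
print(masked[:, 1].abs().sum())                       # tensor(0.) -- head 1 contributes nothing
print(torch.equal(masked[:, 0], attn_weights[:, 0]))  # True -- other heads unchanged
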
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->Clap
class ClapTextSelfAttention(nn.Module): class ClapTextSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None): def __init__(self, config):
super().__init__() super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError( raise ValueError(
@ -1100,6 +1124,7 @@ class ClapTextSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})" f"heads ({config.num_attention_heads})"
) )
self.config = config
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -1109,20 +1134,12 @@ class ClapTextSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr( self.attention_dropout = config.attention_probs_dropout_prob
config, "position_embedding_type", "absolute" self.scaling = self.attention_head_size**-0.5
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -1132,96 +1149,33 @@ class ClapTextSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states) input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
# and values come from an encoder; the attention mask needs to be key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
# such that the encoder's padding tokens are not attended to. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None: attention_interface: Callable = eager_attention_forward
# reuse k,v, cross_attentions if self.config._attn_implementation != "eager":
key_layer = past_key_value[0] attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer) attn_output, attn_weights = attention_interface(
self,
use_cache = past_key_value is not None query_states,
if self.is_decoder: key_states,
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. value_states,
# Further calls to cross_attention layer can then reuse all cross-attention attention_mask,
# key/value_states (first "if" case) dropout=0.0 if not self.training else self.attention_dropout,
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of scaling=self.scaling,
# all previous decoder key/value_states. Further calls to uni-directional self-attention head_mask=head_mask,
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) **kwargs,
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
) )
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) attn_output = attn_output.reshape(*input_shape, -1).contiguous()
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in ClapTextModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs return outputs
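
When config._attn_implementation is not "eager", the refactored self-attention looks the backend up in ALL_ATTENTION_FUNCTIONS and calls it with the same (module, query, key, value, mask, ...) contract. A toy illustration of that dispatch pattern — the registry and the sdpa_like_attention backend below are stand-ins rather than the library's registry, and the scale= argument of scaled_dot_product_attention needs a reasonably recent PyTorch:

from typing import Callable

import torch


def sdpa_like_attention(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Same contract as eager_attention_forward: inputs are (batch, heads, seq, head_dim),
    # output comes back as (batch, seq, heads, head_dim); attention weights are not materialized.
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return out.transpose(1, 2).contiguous(), None


TOY_ATTENTION_FUNCTIONS: dict[str, Callable] = {"sdpa": sdpa_like_attention}

q = k = v = torch.randn(2, 4, 7, 16)
attn_output, _ = TOY_ATTENTION_FUNCTIONS["sdpa"](None, q, k, v, None, scaling=16**-0.5)
print(attn_output.shape)   # torch.Size([2, 7, 4, 16])
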
@ -1240,18 +1194,11 @@ class ClapTextSelfOutput(nn.Module):
return hidden_states return hidden_states
CLAP_TEXT_SELF_ATTENTION_CLASSES = { # Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->Clap
"eager": ClapTextSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT
class ClapTextAttention(nn.Module): class ClapTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None): def __init__(self, config):
super().__init__() super().__init__()
self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( self.self = ClapTextSelfAttention(config)
config, position_embedding_type=position_embedding_type
)
self.output = ClapTextSelfOutput(config) self.output = ClapTextSelfOutput(config)
self.pruned_heads = set() self.pruned_heads = set()
@ -1273,6 +1220,9 @@ class ClapTextAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -1282,15 +1232,14 @@ class ClapTextAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
encoder_hidden_states, output_attentions=output_attentions,
encoder_attention_mask, **kwargs,
past_key_value,
output_attentions,
) )
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -1328,22 +1277,19 @@ class ClapTextOutput(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText # Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->Clap
class ClapTextLayer(GradientCheckpointingLayer): class ClapTextLayer(GradientCheckpointingLayer):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1 self.seq_len_dim = 1
self.attention = ClapTextAttention(config) self.attention = ClapTextAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = ClapTextAttention(config, position_embedding_type="absolute")
self.intermediate = ClapTextIntermediate(config) self.intermediate = ClapTextIntermediate(config)
self.output = ClapTextOutput(config) self.output = ClapTextOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -1353,60 +1299,23 @@ class ClapTextLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
past_key_value=self_attn_past_key_value, **kwargs,
) )
attention_output = self_attention_outputs[0] attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward( layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
) )
outputs = (layer_output,) + outputs outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs return outputs
def feed_forward_chunk(self, attention_output): def feed_forward_chunk(self, attention_output):
@ -1415,14 +1324,19 @@ class ClapTextLayer(GradientCheckpointingLayer):
return layer_output return layer_output
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ClapText # Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->Clap
class ClapTextEncoder(nn.Module): class ClapTextEncoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
self.layer = nn.ModuleList([ClapTextLayer(config) for _ in range(config.num_hidden_layers)]) self.layer = nn.ModuleList([ClapTextLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -1435,65 +1349,36 @@ class ClapTextEncoder(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True, return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: **kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states=hidden_states,
attention_mask, attention_mask=attention_mask,
layer_head_mask, head_mask=layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
**kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],) all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict: return BaseModelOutput(
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
) )
@ -1643,6 +1528,11 @@ class ClapTextModel(ClapPreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1666,11 +1556,6 @@ class ClapTextModel(ClapPreTrainedModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
@ -1684,11 +1569,8 @@ class ClapTextModel(ClapPreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) attention_mask = torch.ones(((batch_size, seq_length)), device=device)
if token_type_ids is None: if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"): if hasattr(self.embeddings, "token_type_ids"):
@ -1702,21 +1584,6 @@ class ClapTextModel(ClapPreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads. # ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
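For context, a minimal sketch (toy shapes, not taken from the changed files) of the broadcastable-mask construction that `get_extended_attention_mask` roughly performs for an encoder: the 2D padding mask becomes an additive [batch_size, 1, 1, seq_length] bias.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                  # (batch_size, seq_length)
extended_attention_mask = attention_mask[:, None, None, :].to(torch.float32)
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(torch.float32).min
print(extended_attention_mask.shape)  # torch.Size([1, 1, 1, 4]); the masked position gets a large negative bias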
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
@ -1725,33 +1592,23 @@ class ClapTextModel(ClapPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
) )
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict: return BaseModelOutputWithPooling(
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
pooler_output=pooled_output, pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states, hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions, attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
) )
@ -1892,6 +1749,7 @@ class ClapModel(ClapPreTrainedModel):
return audio_features return audio_features
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1947,7 +1805,7 @@ class ClapModel(ClapPreTrainedModel):
is_longer=is_longer, is_longer=is_longer,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
text_outputs = self.text_model( text_outputs = self.text_model(
@ -1956,7 +1814,7 @@ class ClapModel(ClapPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
audio_embeds = audio_outputs[1] if not return_dict else audio_outputs.pooler_output audio_embeds = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
@ -1981,10 +1839,6 @@ class ClapModel(ClapPreTrainedModel):
audio_loss = contrastive_loss(logits_per_audio.t()) audio_loss = contrastive_loss(logits_per_audio.t())
loss = (caption_loss + audio_loss) / 2.0 loss = (caption_loss + audio_loss) / 2.0
if not return_dict:
output = (logits_per_audio, logits_per_text, text_embeds, audio_embeds, text_outputs, audio_outputs)
return ((loss,) + output) if loss is not None else output
return ClapOutput( return ClapOutput(
loss=loss, loss=loss,
logits_per_audio=logits_per_audio, logits_per_audio=logits_per_audio,
@ -2013,6 +1867,7 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.text_model.embeddings.word_embeddings = value self.text_model.embeddings.word_embeddings = value
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -2045,17 +1900,13 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output
text_embeds = self.text_projection(pooled_output) text_embeds = self.text_projection(pooled_output)
if not return_dict:
outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
return tuple(output for output in outputs if output is not None)
return ClapTextModelOutput( return ClapTextModelOutput(
text_embeds=text_embeds, text_embeds=text_embeds,
last_hidden_state=text_outputs.last_hidden_state, last_hidden_state=text_outputs.last_hidden_state,
@ -2079,6 +1930,7 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
def get_input_embeddings(self) -> nn.Module: def get_input_embeddings(self) -> nn.Module:
return self.audio_model.audio_encoder.patch_embed.proj return self.audio_model.audio_encoder.patch_embed.proj
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -2123,17 +1975,13 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
is_longer=is_longer, is_longer=is_longer,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
audio_embeds = self.audio_projection(pooled_output) audio_embeds = self.audio_projection(pooled_output)
if not return_dict:
outputs = (audio_embeds, audio_outputs[0]) + audio_outputs[2:]
return tuple(output for output in outputs if output is not None)
return ClapAudioModelOutput( return ClapAudioModelOutput(
audio_embeds=audio_embeds, audio_embeds=audio_embeds,
last_hidden_state=audio_outputs.last_hidden_state, last_hidden_state=audio_outputs.last_hidden_state,

View File

@ -28,7 +28,7 @@ from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepa
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging, torch_int from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig
@ -490,6 +490,7 @@ class CLIPSegEncoder(nn.Module):
self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@can_return_tuple
def forward( def forward(
self, self,
inputs_embeds, inputs_embeds,
@ -555,8 +556,6 @@ class CLIPSegEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
) )

View File

@ -311,7 +311,7 @@ class CsmAttention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)

View File

@ -45,6 +45,7 @@ from ...modeling_outputs import (
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available
from ...utils.deprecation import deprecate_kwarg
from .configuration_data2vec_audio import Data2VecAudioConfig from .configuration_data2vec_audio import Data2VecAudioConfig
@ -240,6 +241,7 @@ class Data2VecAudioAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -247,7 +249,7 @@ class Data2VecAudioAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None, past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs # TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
@ -268,42 +270,9 @@ class Data2VecAudioAttention(nn.Module):
# get query proj # get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj current_states = key_value_states if is_cross_attention else hidden_states
# `past_key_value[0].shape[2] == key_value_states.shape[1]` key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# is checking that the `sequence_length` of the `past_key_value` is the same as value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
@ -325,7 +294,7 @@ class Data2VecAudioAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, None
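A minimal, hypothetical sketch of the simplified key/value source selection above (head reshaping, caching, and masking omitted): cross-attention projects the encoder states, self-attention projects the current hidden states.

import torch
from torch import nn

class KVSource(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, key_value_states=None):
        # Pick the projection source exactly as in the simplified attention above.
        is_cross_attention = key_value_states is not None
        current_states = key_value_states if is_cross_attention else hidden_states
        return self.k_proj(current_states), self.v_proj(current_states)

mod = KVSource(8)
k_self, _ = mod(torch.randn(2, 5, 8))                         # self-attention
k_cross, _ = mod(torch.randn(2, 5, 8), torch.randn(2, 7, 8))  # cross-attention
print(k_self.shape, k_cross.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 7, 8])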
class Data2VecAudioFeedForward(nn.Module): class Data2VecAudioFeedForward(nn.Module):

View File

@ -634,7 +634,6 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
@auto_docstring @auto_docstring
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None,

View File

@ -281,7 +281,7 @@ class DiaSelfAttention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)

View File

@ -405,11 +405,6 @@ class DonutSwinSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -418,11 +413,11 @@ class DonutSwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states) hidden_shape = (batch_size, dim, -1, self.attention_head_size)
key_layer = self.transpose_for_scores(self.key(hidden_states)) query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
query_layer = self.transpose_for_scores(mixed_query_layer) value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Take the dot product between "query" and "key" to get the raw attention scores. # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

View File

@ -189,7 +189,7 @@ class Emu3Attention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
@ -1537,20 +1537,26 @@ class Emu3Model(Emu3PreTrainedModel):
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
) )
if pixel_values is not None and inputs_embeds is not None: if inputs_embeds is None:
raise ValueError( inputs_embeds = self.get_input_embeddings()(input_ids)
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None: if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes) image_embeds = self.get_image_features(pixel_values, image_sizes)
image_embeds = torch.cat(image_embeds, dim=0)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.text_model( outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,

View File

@ -1033,20 +1033,26 @@ class Emu3Model(Emu3PreTrainedModel):
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
) )
if pixel_values is not None and inputs_embeds is not None: if inputs_embeds is None:
raise ValueError( inputs_embeds = self.get_input_embeddings()(input_ids)
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None: if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes) image_embeds = self.get_image_features(pixel_values, image_sizes)
image_embeds = torch.cat(image_embeds, dim=0)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.text_model( outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,

View File

@ -26,14 +26,15 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput, MaskedLMOutput,
SequenceClassifierOutput, SequenceClassifierOutput,
TokenClassifierOutput, TokenClassifierOutput,
) )
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_esm import EsmConfig from .configuration_esm import EsmConfig
@ -187,12 +188,16 @@ class EsmEmbeddings(nn.Module):
self.mask_token_id = config.mask_token_id self.mask_token_id = config.mask_token_id
def forward( def forward(
self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 self,
input_ids=None,
attention_mask=None,
position_ids=None,
inputs_embeds=None,
): ):
if position_ids is None: if position_ids is None:
if input_ids is not None: if input_ids is not None:
# Create the position ids from the input token ids. Any padded tokens remain padded. # Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
else: else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
@ -281,11 +286,7 @@ class EsmSelfAttention(nn.Module):
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: @deprecate_kwarg("past_key_value", version="4.54.0")
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -296,32 +297,22 @@ class EsmSelfAttention(nn.Module):
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states) hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be # and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to. # such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None: if is_cross_attention:
# reuse k,v, cross_attentions key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = past_key_value[0] value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else: else:
key_layer = self.transpose_for_scores(self.key(hidden_states)) key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.transpose_for_scores(self.value(hidden_states)) value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
query_layer = self.transpose_for_scores(mixed_query_layer)
# Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
# ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
@ -329,16 +320,6 @@ class EsmSelfAttention(nn.Module):
# ESM code and fix rotary embeddings. # ESM code and fix rotary embeddings.
query_layer = query_layer * self.attention_head_size**-0.5 query_layer = query_layer * self.attention_head_size**-0.5
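A quick numerical check, on toy tensors, that scaling the query by head_dim**-0.5 gives the same attention logits as scaling the query-key products:

import torch

q, k = torch.randn(2, 4, 16), torch.randn(2, 4, 16)
d = q.shape[-1]
scaled_logits = (q @ k.transpose(-1, -2)) * d**-0.5  # BERT-style: scale the logits
scaled_query = (q * d**-0.5) @ k.transpose(-1, -2)   # ESM-style: scale the query first
print(torch.allclose(scaled_logits, scaled_query, atol=1e-5))  # True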
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
if self.position_embedding_type == "rotary": if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
@ -385,7 +366,7 @@ class EsmSelfAttention(nn.Module):
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder: if self.is_decoder:
outputs = outputs + (past_key_value,) outputs = outputs + (None,)
return outputs return outputs
@ -418,6 +399,7 @@ class EsmFlashAttention2(EsmSelfAttention):
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
self.dropout_prob = config.attention_probs_dropout_prob self.dropout_prob = config.attention_probs_dropout_prob
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -441,7 +423,6 @@ class EsmFlashAttention2(EsmSelfAttention):
head_mask, head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions, output_attentions,
) )
@ -450,9 +431,6 @@ class EsmFlashAttention2(EsmSelfAttention):
query_layer = self.transpose_for_scores(self.query(hidden_states)) query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states)) key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states))
if past_key_value is not None:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
# In PEFT, usually we cast the layer norms in float32 for training stability reasons # In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need # therefore the input hidden states gets silently casted in float32. Hence, we need
@ -514,7 +492,7 @@ class EsmFlashAttention2(EsmSelfAttention):
outputs = (attn_output, None) outputs = (attn_output, None)
if self.is_decoder: if self.is_decoder:
outputs = outputs + (past_key_value,) outputs = outputs + (None,)
return outputs return outputs
@ -551,6 +529,7 @@ class EsmAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states, hidden_states,
@ -564,12 +543,11 @@ class EsmAttention(nn.Module):
hidden_states_ln = self.LayerNorm(hidden_states) hidden_states_ln = self.LayerNorm(hidden_states)
self_outputs = self.self( self_outputs = self.self(
hidden_states_ln, hidden_states_ln,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
past_key_value, output_attentions=output_attentions,
output_attentions,
) )
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -616,6 +594,7 @@ class EsmLayer(GradientCheckpointingLayer):
self.output = EsmOutput(config) self.output = EsmOutput(config)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states, hidden_states,
@ -626,25 +605,20 @@ class EsmLayer(GradientCheckpointingLayer):
past_key_value=None, past_key_value=None,
output_attentions=False, output_attentions=False,
): ):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
) )
attention_output = self_attention_outputs[0] attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache # if decoder, the last output is tuple of self-attn cache
if self.is_decoder: if self.is_decoder:
outputs = self_attention_outputs[1:-1] outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else: else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"): if not hasattr(self, "crossattention"):
raise AttributeError( raise AttributeError(
@ -652,31 +626,24 @@ class EsmLayer(GradientCheckpointingLayer):
" with cross-attention layers by setting `config.add_cross_attention=True`" " with cross-attention layers by setting `config.add_cross_attention=True`"
) )
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention( cross_attention_outputs = self.crossattention(
attention_output, attention_output,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
cross_attn_past_key_value, output_attentions=output_attentions,
output_attentions,
) )
attention_output = cross_attention_outputs[0] attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = self.feed_forward_chunk(attention_output) layer_output = self.feed_forward_chunk(attention_output)
outputs = (layer_output,) + outputs outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output # if decoder, return the attn key/values as the last output
if self.is_decoder: if self.is_decoder:
outputs = outputs + (present_key_value,) outputs = outputs + (None,)
return outputs return outputs
def feed_forward_chunk(self, attention_output): def feed_forward_chunk(self, attention_output):
@ -694,6 +661,9 @@ class EsmEncoder(nn.Module):
self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False self.gradient_checkpointing = False
@deprecate_kwarg("past_key_value", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward( def forward(
self, self,
hidden_states, hidden_states,
@ -707,38 +677,26 @@ class EsmEncoder(nn.Module):
output_hidden_states=False, output_hidden_states=False,
return_dict=True, return_dict=True,
): ):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states=hidden_states,
attention_mask, attention_mask=attention_mask,
layer_head_mask, head_mask=layer_head_mask,
encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
past_key_value, output_attentions=output_attentions,
output_attentions,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],) all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention: if self.config.add_cross_attention:
@ -750,21 +708,8 @@ class EsmEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict: return BaseModelOutputWithCrossAttentions(
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
@ -863,6 +808,9 @@ class EsmModel(EsmPreTrainedModel):
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -903,11 +851,6 @@ class EsmModel(EsmPreTrainedModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
@ -921,11 +864,8 @@ class EsmModel(EsmPreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) attention_mask = torch.ones(((batch_size, seq_length)), device=device)
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
extended_attention_mask = attention_mask extended_attention_mask = attention_mask
@ -958,7 +898,6 @@ class EsmModel(EsmPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
) )
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
@ -966,22 +905,16 @@ class EsmModel(EsmPreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions( return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
pooler_output=pooled_output, pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states, hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions, attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions, cross_attentions=encoder_outputs.cross_attentions,
@ -1025,6 +958,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
def set_output_embeddings(self, new_embeddings): def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings self.lm_head.decoder = new_embeddings
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1058,7 +992,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output) prediction_scores = self.lm_head(sequence_output)
@ -1070,10 +1004,6 @@ class EsmForMaskedLM(EsmPreTrainedModel):
labels = labels.to(prediction_scores.device) labels = labels.to(prediction_scores.device)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput( return MaskedLMOutput(
loss=masked_lm_loss, loss=masked_lm_loss,
logits=prediction_scores, logits=prediction_scores,
@ -1125,6 +1055,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
self.post_init() self.post_init()
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1154,7 +1085,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
@ -1184,10 +1115,6 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
loss_fct = BCEWithLogitsLoss() loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels) loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, loss=loss,
logits=logits, logits=logits,
@ -1210,6 +1137,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
self.post_init() self.post_init()
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1237,7 +1165,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@ -1252,10 +1180,6 @@ class EsmForTokenClassification(EsmPreTrainedModel):
labels = labels.to(logits.device) labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, loss=loss,
logits=logits, logits=logits,
@ -1283,7 +1207,7 @@ class EsmClassificationHead(nn.Module):
return x return x
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): def create_position_ids_from_input_ids(input_ids, padding_idx):
""" """
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`. are ignored. This is modified from fairseq's `utils.make_positions`.
@ -1295,7 +1219,7 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
""" """
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int() mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
return incremental_indices.long() + padding_idx return incremental_indices.long() + padding_idx
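A small worked example of the simplified helper (dummy token ids; `padding_idx=1` picked for illustration): non-padding tokens are numbered from `padding_idx + 1`, padding tokens keep `padding_idx`.

import torch

def create_position_ids_from_input_ids(input_ids, padding_idx):
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx

input_ids = torch.tensor([[5, 7, 9, 1, 1]])  # 1 is the padding id here
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1]])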

View File

@ -206,12 +206,20 @@ class FuyuModel(FuyuPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
if image_patches is not None:
patch_embeddings = self.get_image_features(image_patches) patch_embeddings = self.get_image_features(image_patches)
patch_embeddings = torch.cat(patch_embeddings, dim=0) patch_embeddings = torch.cat(patch_embeddings, dim=0)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) if input_ids is None:
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype) patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)

View File

@ -222,7 +222,7 @@ class GemmaAttention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)

View File

@ -898,9 +898,11 @@ class Gemma3Model(Gemma3PreTrainedModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()( special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
) )
special_image_mask = special_image_mask.all(-1)
else: else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]

View File

@ -800,9 +800,11 @@ class Gemma3Model(PaliGemmaModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()( special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
) )
special_image_mask = special_image_mask.all(-1)
else: else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]

View File

@ -1135,9 +1135,17 @@ class Gemma3nTextAltUp(nn.Module):
corrected += predictions # add the original input corrected += predictions # add the original input
return corrected.contiguous().type_as(activated) return corrected.contiguous().type_as(activated)
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
"""
This is only defined as the `forward` so that accelerate hooks can correctly move `correct_output_scale`
(which is an nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
`scale_corrected_output`.
"""
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size].""" """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) return self.forward(corrected)
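A self-contained sketch of the same pattern with a hypothetical `ScaleOnly` module: routing the multiply through `forward` means device-offloading hooks, which wrap `Module.forward`, move the bare `nn.Parameter` together with the rest of the model.

import torch
from torch import nn

class ScaleOnly(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        # A bare parameter with no submodules; hooks on forward() keep it on the right device.
        self.scale = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x.type_as(self.scale) * self.scale).type_as(x)

    def scale_output(self, x: torch.Tensor) -> torch.Tensor:
        # The public entry point simply delegates to forward(), mirroring scale_corrected_output.
        return self.forward(x)

print(ScaleOnly(8)(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])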
class Gemma3nTextRotaryEmbedding(nn.Module): class Gemma3nTextRotaryEmbedding(nn.Module):
@ -1290,7 +1298,7 @@ class Gemma3nTextAttention(nn.Module):
self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False) self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing) # Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
layer_type = config.layer_types[layer_idx] layer_type = config.layer_types[layer_idx]
self.kv_shared_layer_index = ( self.kv_shared_layer_index = (
@ -1319,21 +1327,22 @@ class Gemma3nTextAttention(nn.Module):
query_states = query_states.transpose(1, 2) query_states = query_states.transpose(1, 2)
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
# HybridCache has complex slicing when layer_type == "sliding_attention" that impacts the shared KV cache. # Device of past layer may be different from current one
indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
# In this case we need special handling of the slice, as the sliding layer's cache has a fixed small size (for full layers, we never go beyond the cache length)
if isinstance(past_key_value, HybridCache) and self.is_sliding: if isinstance(past_key_value, HybridCache) and self.is_sliding:
max_length = past_key_value.sliding_window max_length = past_key_value.sliding_window
if cache_position.shape[0] > max_length: indices = (
# If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache, slice(0, max_length)
# slice into the entire cache. if cache_position.shape[0] > max_length
indices = slice(0, max_length) else cache_position.clamp(min=0, max=max_length - 1)
else: )
# If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
indices = cache_position.clamp(min=0, max=max_length - 1)
else:
indices = cache_position
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] # Device of past layer may be different from current one
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
query_states.device
)
else: else:
key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states) key_states = self.k_norm(key_states)
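The index selection above can be read in isolation; a small sketch under stated assumptions (a hypothetical shared_kv_indices helper, HybridCache-style sliding window of size sliding_window) of the same branching: slice the whole fixed-size cache when the prefill exceeds it, otherwise clamp positions into it, and index full-attention layers directly.

from typing import Optional

import torch

def shared_kv_indices(cache_position: torch.Tensor, sliding_window: Optional[int]):
    # Hypothetical helper mirroring the branch above; sliding_window is None for full-attention layers.
    if sliding_window is not None:
        if cache_position.shape[0] > sliding_window:
            # Prefill larger than the fixed-size sliding cache: slice the whole cache.
            return slice(0, sliding_window)
        # Prefill fits, or we are decoding: clamp positions into the cache.
        return cache_position.clamp(min=0, max=sliding_window - 1)
    # Full-attention layers index the shared cache with the positions directly.
    return cache_position

# A decode step at position 4100 against a 4096-slot sliding cache lands on the last slot.
print(shared_kv_indices(torch.tensor([4100]), 4096))  # tensor([4095])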
@ -1447,10 +1456,9 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
first_prediction = corrected_predictions[self.config.altup_active_idx] first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
first_prediction_clone = first_prediction.clone()
if self.config.altup_correct_scale: if self.config.altup_correct_scale:
first_prediction = self.altup.scale_corrected_output(first_prediction_clone) first_prediction = self.altup.scale_corrected_output(first_prediction)
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...) # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
first_prediction = self.per_layer_input_gate(first_prediction) first_prediction = self.per_layer_input_gate(first_prediction)
@ -1475,7 +1483,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
config_class = Gemma3nConfig config_class = Gemma3nConfig
base_model_prefix = "" base_model_prefix = ""
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["Gemma3nDecoderLayer"] _no_split_modules = ["Gemma3nTextDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"] _skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True _supports_flash_attn_3 = True
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
@ -1656,18 +1664,17 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids) position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
# Expand hidden_states to support per-layer inputs # Expand hidden_states to support per-layer inputs
target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
epsilon_tensor = torch.tensor(torch.finfo().min) epsilon_tensor = torch.tensor(1e-5)
temp_hidden_states = [hidden_states_0] temp_hidden_states = [hidden_states_0]
for i in range(1, self.config.altup_num_inputs): for i in range(1, self.config.altup_num_inputs):
# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...) # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0) altup_proj = self.altup_projections[i - 1](hidden_states_0)
current_hidden_state = altup_proj.type(hidden_states_0.dtype) current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
current_hidden_state = current_hidden_state * ( new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
)
temp_hidden_states.append(current_hidden_state) temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size] hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
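The loop above rescales each projected altup stream so its per-token RMS matches the base hidden states, with the epsilon now applied to the mean square before the square root; a standalone sketch (hypothetical match_magnitude helper) of that computation:

import torch

def match_magnitude(projected: torch.Tensor, reference: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Rescale `projected` so its per-token RMS matches that of `reference`.
    target_magnitude = torch.mean(reference**2, dim=-1, keepdim=True) ** 0.5
    new_magnitude = torch.mean(projected**2, dim=-1, keepdim=True)   # mean square, not RMS yet
    new_magnitude = torch.sqrt(torch.clamp(new_magnitude, min=eps))  # epsilon guards the sqrt
    return projected * target_magnitude / new_magnitude

x = torch.randn(2, 4, 8)
y = 0.01 * torch.randn(2, 4, 8)
print(match_magnitude(y, x).pow(2).mean(-1).sqrt())  # per-token RMS now tracks that of x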
@ -1685,9 +1692,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
position_embeddings_global=position_embeddings_global, position_embeddings_global,
position_embeddings_local=position_embeddings_local, position_embeddings_local,
per_layer_input=per_layer_input, per_layer_input,
attention_mask=causal_mask, attention_mask=causal_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
@ -1712,11 +1719,10 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
for i in range(1, self.config.altup_num_inputs): for i in range(1, self.config.altup_num_inputs):
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...) # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i]) altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype) current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
current_hidden_state = current_hidden_state * ( new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
)
temp_hidden_states.append(current_hidden_state) temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states) hidden_states = torch.stack(temp_hidden_states)
@ -1743,7 +1749,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
per_layer_inputs: Optional[torch.Tensor] = None, per_layer_inputs: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds) per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype) per_layer_projection *= self.per_layer_projection_scale.to(
dtype=inputs_embeds.dtype, device=per_layer_projection.device
)
per_layer_projection = per_layer_projection.reshape( per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1], *inputs_embeds.shape[:-1],
self.config.num_hidden_layers, self.config.num_hidden_layers,
@ -1758,7 +1766,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
# per-layer inputs are sometimes padded with zeros, slice the relevant embeddings. # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype) return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
dtype=inputs_embeds.dtype, device=per_layer_projection.device
)
@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.") @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
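For context on the scales being cast above, a reduced sketch of the per-layer input path with hypothetical names (PerLayerInputs, projection_scale, input_scale) and illustrative scale values: project, rescale in the embedding dtype on the projection's device, reshape to one slice per layer, then combine with the provided per-layer inputs.

import torch
from torch import nn

class PerLayerInputs(nn.Module):
    # Hypothetical reduced version of the per-layer input path shown above.
    def __init__(self, hidden_size: int, num_layers: int, per_layer_dim: int):
        super().__init__()
        self.num_layers = num_layers
        self.per_layer_dim = per_layer_dim
        self.proj = nn.Linear(hidden_size, num_layers * per_layer_dim, bias=False)
        self.register_buffer("projection_scale", torch.tensor(hidden_size**-0.5))  # illustrative value
        self.register_buffer("input_scale", torch.tensor(2.0**-0.5))               # illustrative value

    def forward(self, inputs_embeds: torch.Tensor, per_layer_inputs: torch.Tensor) -> torch.Tensor:
        projection = self.proj(inputs_embeds)
        # Move the scale to the embedding dtype and the projection's device before multiplying.
        projection = projection * self.projection_scale.to(dtype=inputs_embeds.dtype, device=projection.device)
        projection = projection.reshape(*inputs_embeds.shape[:-1], self.num_layers, self.per_layer_dim)
        return (projection + per_layer_inputs) * self.input_scale.to(dtype=inputs_embeds.dtype, device=projection.device)

m = PerLayerInputs(hidden_size=16, num_layers=4, per_layer_dim=8)
print(m(torch.randn(2, 3, 16), torch.randn(2, 3, 4, 8)).shape)  # torch.Size([2, 3, 4, 8])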
View File
@ -1685,9 +1685,17 @@ class Gemma3nTextAltUp(nn.Module):
corrected += predictions # add the original input corrected += predictions # add the original input
return corrected.contiguous().type_as(activated) return corrected.contiguous().type_as(activated)
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
"""
This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
`scale_corrected_output`
"""
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size].""" """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) return self.forward(corrected)
class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding): class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding):
@ -1732,7 +1740,7 @@ class Gemma3nTextAttention(Gemma3Attention):
self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False) self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing) # Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
layer_type = config.layer_types[layer_idx] layer_type = config.layer_types[layer_idx]
self.kv_shared_layer_index = ( self.kv_shared_layer_index = (
@ -1761,21 +1769,22 @@ class Gemma3nTextAttention(Gemma3Attention):
query_states = query_states.transpose(1, 2) query_states = query_states.transpose(1, 2)
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
# HybridCache has complex slicing when layer_type == "sliding_attention" that impacts the shared KV cache. # Device of past layer may be different from current one
indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond the cache length)
if isinstance(past_key_value, HybridCache) and self.is_sliding: if isinstance(past_key_value, HybridCache) and self.is_sliding:
max_length = past_key_value.sliding_window max_length = past_key_value.sliding_window
if cache_position.shape[0] > max_length: indices = (
# If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache, slice(0, max_length)
# slice into the entire cache. if cache_position.shape[0] > max_length
indices = slice(0, max_length) else cache_position.clamp(min=0, max=max_length - 1)
else: )
# If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
indices = cache_position.clamp(min=0, max=max_length - 1)
else:
indices = cache_position
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] # Device of past layer may be different from current one
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
query_states.device
)
else: else:
key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states) key_states = self.k_norm(key_states)
@ -1880,10 +1889,9 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
first_prediction = corrected_predictions[self.config.altup_active_idx] first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
first_prediction_clone = first_prediction.clone()
if self.config.altup_correct_scale: if self.config.altup_correct_scale:
first_prediction = self.altup.scale_corrected_output(first_prediction_clone) first_prediction = self.altup.scale_corrected_output(first_prediction)
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...) # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
first_prediction = self.per_layer_input_gate(first_prediction) first_prediction = self.per_layer_input_gate(first_prediction)
@ -1906,7 +1914,7 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
class Gemma3nPreTrainedModel(Gemma2PreTrainedModel): class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
config_class = Gemma3nConfig config_class = Gemma3nConfig
base_model_prefix = "" base_model_prefix = ""
_no_split_modules = ["Gemma3nDecoderLayer"] _no_split_modules = ["Gemma3nTextDecoderLayer"]
def _init_weights(self, module): def _init_weights(self, module):
# important: this ported version of Gemma2 isn't meant for training from scratch - only # important: this ported version of Gemma2 isn't meant for training from scratch - only
@ -1995,7 +2003,9 @@ class Gemma3nTextModel(Gemma3TextModel):
per_layer_inputs: Optional[torch.Tensor] = None, per_layer_inputs: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds) per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype) per_layer_projection *= self.per_layer_projection_scale.to(
dtype=inputs_embeds.dtype, device=per_layer_projection.device
)
per_layer_projection = per_layer_projection.reshape( per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1], *inputs_embeds.shape[:-1],
self.config.num_hidden_layers, self.config.num_hidden_layers,
@ -2010,7 +2020,9 @@ class Gemma3nTextModel(Gemma3TextModel):
# per-layer inputs are sometimes padded with zeros, slice the relevant embeddings. # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype) return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
dtype=inputs_embeds.dtype, device=per_layer_projection.device
)
@can_return_tuple @can_return_tuple
@auto_docstring @auto_docstring
@ -2091,18 +2103,17 @@ class Gemma3nTextModel(Gemma3TextModel):
position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids) position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
# Expand hidden_states to support per-layer inputs # Expand hidden_states to support per-layer inputs
target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
epsilon_tensor = torch.tensor(torch.finfo().min) epsilon_tensor = torch.tensor(1e-5)
temp_hidden_states = [hidden_states_0] temp_hidden_states = [hidden_states_0]
for i in range(1, self.config.altup_num_inputs): for i in range(1, self.config.altup_num_inputs):
# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...) # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0) altup_proj = self.altup_projections[i - 1](hidden_states_0)
current_hidden_state = altup_proj.type(hidden_states_0.dtype) current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
current_hidden_state = current_hidden_state * ( new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
)
temp_hidden_states.append(current_hidden_state) temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size] hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
@ -2120,9 +2131,9 @@ class Gemma3nTextModel(Gemma3TextModel):
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
position_embeddings_global=position_embeddings_global, position_embeddings_global,
position_embeddings_local=position_embeddings_local, position_embeddings_local,
per_layer_input=per_layer_input, per_layer_input,
attention_mask=causal_mask, attention_mask=causal_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
@ -2147,11 +2158,10 @@ class Gemma3nTextModel(Gemma3TextModel):
for i in range(1, self.config.altup_num_inputs): for i in range(1, self.config.altup_num_inputs):
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...) # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i]) altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype) current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
current_hidden_state = current_hidden_state * ( new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
)
temp_hidden_states.append(current_hidden_state) temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states) hidden_states = torch.stack(temp_hidden_states)
View File
@ -39,6 +39,7 @@ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
auto_docstring, auto_docstring,
can_return_tuple,
logging, logging,
torch_int, torch_int,
) )
@ -770,6 +771,7 @@ class GitVisionEncoder(nn.Module):
self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@can_return_tuple
def forward( def forward(
self, self,
inputs_embeds, inputs_embeds,
@ -836,8 +838,6 @@ class GitVisionEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
) )
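The decorator added above lets the encoder always construct a BaseModelOutput and removes the manual `if not return_dict` branch; a rough sketch of the idea (not the actual transformers implementation, and the name can_return_tuple_sketch is made up here), assuming the wrapped output exposes to_tuple():

import functools

def can_return_tuple_sketch(forward):
    # Hypothetical simplified stand-in: forward always builds a ModelOutput-like object,
    # and the wrapper converts it to a plain tuple when a tuple was requested.
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        output = forward(self, *args, **kwargs)
        use_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
        return output if use_dict else output.to_tuple()  # assumes a ModelOutput-like object
    return wrapper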
View File
@ -184,7 +184,7 @@ class GlmAttention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
View File
@ -242,7 +242,7 @@ class Glm4Attention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
View File
@ -279,14 +279,15 @@ def eager_attention_forward(
class Glm4vVisionAttention(nn.Module): class Glm4vVisionAttention(nn.Module):
def __init__(self, config: Glm4vVisionConfig) -> None: def __init__(self, config: Glm4vVisionConfig) -> None:
super().__init__() super().__init__()
self.config = config self.dim = config.hidden_size
self.num_heads = config.num_heads self.num_heads = config.num_heads
self.head_dim = config.hidden_size // self.num_heads self.head_dim = self.dim // self.num_heads
self.num_key_value_groups = 1 self.num_key_value_groups = 1 # needed for eager attention
self.scale = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.scaling = self.head_dim**-0.5
self.config = config
self.attention_dropout = config.attention_dropout
self.is_causal = False self.is_causal = False
def forward( def forward(
@ -295,23 +296,31 @@ class Glm4vVisionAttention(nn.Module):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[FlashAttentionKwargs], attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
seq_length = hidden_states.shape[0] seq_length = hidden_states.shape[0]
query_states, key_states, value_states = ( query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
) )
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
query_states = query_states.transpose(0, 1).unsqueeze(0) query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
@ -322,13 +331,17 @@ class Glm4vVisionAttention(nn.Module):
query_states, query_states,
key_states, key_states,
value_states, value_states,
attention_mask, attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale, scaling=self.scaling,
is_causal=self.is_causal, cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs, **kwargs,
) )
attn_output = attn_output.squeeze(0)
attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output) attn_output = self.proj(attn_output)
return attn_output return attn_output
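The new flash-attention kwargs above describe a packed sequence through cu_seqlens; a tiny sketch with hypothetical segment sizes of how max_seqlen, passed as max_length_q and max_length_k, falls out of the boundary tensor:

import torch

# cu_seqlens holds cumulative segment boundaries for the packed sequence,
# e.g. three images of 4, 9 and 4 patches packed into 17 tokens.
cu_seqlens = torch.tensor([0, 4, 13, 17])
seg_lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([4, 9, 4])
max_seqlen = int(seg_lengths.max().item())      # 9, the longest segment
print(seg_lengths.tolist(), max_seqlen)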
@ -348,6 +361,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = hidden_states + self.attn( hidden_states = hidden_states + self.attn(
@ -355,6 +369,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs, **kwargs,
) )
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@ -452,6 +467,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids return rotary_pos_emb, pos_ids
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks, though it will not be a 100% match for FA2's `varlen` path.
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
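A self-contained sketch of the additive block-diagonal mask built by _prepare_attention_mask, using a toy packed sequence: positions inside the same cu_seqlens block attend bidirectionally, everything else stays at the dtype minimum.

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, dtype=torch.float32, device="cpu") -> torch.Tensor:
    seq_length = int(cu_seqlens[-1])
    mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype, device=device)
    for i in range(1, len(cu_seqlens)):
        start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        mask[..., start:end, start:end] = 0  # bidirectional attention inside one block
    return mask

mask = block_diagonal_mask(torch.tensor([0, 4, 13, 17]))
print((mask[0, 0] == 0).sum().item())  # 4*4 + 9*9 + 4*4 = 113 unmasked pairs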
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
""" """
Args: Args:
@ -481,14 +515,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
for blk in self.blocks: for blk in self.blocks:
if self.gradient_checkpointing and self.training: hidden_states = blk(
hidden_states = self._gradient_checkpointing_func( hidden_states,
blk.__call__, hidden_states, cu_seqlens, None, position_embeddings cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
) )
else:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
hidden_states = self.post_layernorm(hidden_states) hidden_states = self.post_layernorm(hidden_states)
@ -1202,50 +1237,59 @@ class Glm4vModel(Glm4vPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0) image_embeds = torch.cat(image_embeds, dim=0)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id
n_image_tokens = image_mask.sum()
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_features = image_embeds.shape[0] n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
) )
mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None: if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0) video_embeds = torch.cat(video_embeds, dim=0)
n_video_tokens = (input_ids == self.config.image_token_id).sum()
if input_ids is None:
video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id
n_video_tokens = (video_mask).sum()
n_video_features = video_embeds.shape[0] n_video_features = video_embeds.shape[0]
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError( raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
) )
mask = input_ids == self.config.image_token_id # GLM-4.1V uses image_token_id for video
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if position_ids is None: if position_ids is None:
attention_mask_tensor = attention_mask attention_mask_tensor = (
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
)
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
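The branches above derive the placeholder mask either from ids or, when only embeddings are passed, by comparing each row against the embedding of the placeholder token; a toy sketch with a hypothetical IMAGE_TOKEN_ID and a fresh embedding table (the comparison only holds for placeholder rows that were not overwritten):

import torch
from torch import nn

IMAGE_TOKEN_ID = 7  # hypothetical placeholder id
embed = nn.Embedding(16, 8)
input_ids = torch.tensor([[1, 7, 7, 2]])
inputs_embeds = embed(input_ids)

# id-based mask (used when input_ids is available)
mask_from_ids = input_ids == IMAGE_TOKEN_ID

# embedding-based mask (fallback when only inputs_embeds is passed)
placeholder = embed(torch.tensor(IMAGE_TOKEN_ID))
mask_from_embeds = (inputs_embeds == placeholder).all(-1)

print(torch.equal(mask_from_ids, mask_from_embeds))  # True for untouched placeholder rows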
@ -1536,6 +1580,7 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
def _get_image_nums_and_video_nums( def _get_image_nums_and_video_nums(
self, self,
input_ids: Optional[torch.LongTensor], input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Get the number of images and videos for each sample to calculate the separation length of the sample tensor. Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@ -1550,6 +1595,26 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
""" """
if inputs_embeds is not None:
is_image = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
is_video_start = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
is_video_end = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
else:
is_image = input_ids == self.config.image_start_token_id is_image = input_ids == self.config.image_start_token_id
is_video_start = input_ids == self.config.video_start_token_id is_video_start = input_ids == self.config.video_start_token_id
is_video_end = input_ids == self.config.video_end_token_id is_video_end = input_ids == self.config.video_end_token_id
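The start/end masks above feed per-sample image and video counts; a toy sketch with hypothetical special token ids of the simplest form of that counting (the real method also tracks video end tokens, which this version only notes):

import torch

IMAGE_START, VIDEO_START, VIDEO_END = 11, 12, 13  # hypothetical special token ids
input_ids = torch.tensor([[11, 5, 11, 12, 9, 13, 11]])

is_image = input_ids == IMAGE_START
is_video_start = input_ids == VIDEO_START
is_video_end = input_ids == VIDEO_END  # used in the real method to pair starts with ends; unused in this toy count

image_nums = is_image.sum(dim=-1)        # images per sample -> tensor([3])
video_nums = is_video_start.sum(dim=-1)  # videos per sample -> tensor([1])
print(image_nums.tolist(), video_nums.tolist())  # [3] [1]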
@ -1588,7 +1653,9 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
def _expand_dict_for_generation_visual(dict_to_expand): def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None) image_grid_thw = model_kwargs.get("image_grid_thw", None)
video_grid_thw = model_kwargs.get("video_grid_thw", None) video_grid_thw = model_kwargs.get("video_grid_thw", None)
image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) image_nums, video_nums = self._get_image_nums_and_video_nums(
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
)
def _repeat_interleave_samples(x, lengths, repeat_times): def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths) samples = torch.split(x, lengths)
@ -1644,9 +1711,6 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand return dict_to_expand
# input_ids is required for expanding visual inputs
# If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
if input_ids is not None and input_ids.numel() != 0:
model_kwargs = _expand_dict_for_generation_visual(model_kwargs) model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
if input_ids is not None: if input_ids is not None:
View File
@ -50,8 +50,8 @@ from ..qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLPreTrainedModel, Qwen2_5_VLPreTrainedModel,
Qwen2_5_VLRotaryEmbedding, Qwen2_5_VLRotaryEmbedding,
Qwen2_5_VLTextModel, Qwen2_5_VLTextModel,
Qwen2_5_VLVisionAttention,
Qwen2_5_VLVisionBlock, Qwen2_5_VLVisionBlock,
apply_rotary_pos_emb_vision,
) )
from ..qwen2_5_vl.processing_qwen2_5_vl import ( from ..qwen2_5_vl.processing_qwen2_5_vl import (
Qwen2_5_VLProcessor, Qwen2_5_VLProcessor,
@ -505,62 +505,12 @@ class Glm4vVisionEmbeddings(nn.Module):
return embeddings return embeddings
class Glm4vVisionAttention(nn.Module): class Glm4vVisionAttention(Qwen2_5_VLVisionAttention):
def __init__(self, config: Glm4vVisionConfig) -> None: def __init__(self, config: Glm4vVisionConfig) -> None:
super().__init__() super().__init__()
self.config = config
self.num_heads = config.num_heads
self.head_dim = config.hidden_size // self.num_heads
self.num_key_value_groups = 1
self.scale = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.squeeze(0)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
class Glm4vVisionBlock(Qwen2_5_VLVisionBlock): class Glm4vVisionBlock(Qwen2_5_VLVisionBlock):
@ -653,6 +603,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb, pos_ids return rotary_pos_emb, pos_ids
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
# blocks, though it will not be a 100% match for FA2's `varlen` path.
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
""" """
Args: Args:
@ -682,14 +651,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
for blk in self.blocks: for blk in self.blocks:
if self.gradient_checkpointing and self.training: hidden_states = blk(
hidden_states = self._gradient_checkpointing_func( hidden_states,
blk.__call__, hidden_states, cu_seqlens, None, position_embeddings cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
) )
else:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
hidden_states = self.post_layernorm(hidden_states) hidden_states = self.post_layernorm(hidden_states)
@ -1267,50 +1237,59 @@ class Glm4vModel(Qwen2_5_VLModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0) image_embeds = torch.cat(image_embeds, dim=0)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id
n_image_tokens = image_mask.sum()
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_features = image_embeds.shape[0] n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
) )
mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None: if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0) video_embeds = torch.cat(video_embeds, dim=0)
n_video_tokens = (input_ids == self.config.image_token_id).sum()
if input_ids is None:
video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id
n_video_tokens = (video_mask).sum()
n_video_features = video_embeds.shape[0] n_video_features = video_embeds.shape[0]
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError( raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
) )
mask = input_ids == self.config.image_token_id # GLM-4.1V uses image_token_id for video
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if position_ids is None: if position_ids is None:
attention_mask_tensor = attention_mask attention_mask_tensor = (
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
)
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
@ -1530,6 +1509,7 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
def _get_image_nums_and_video_nums( def _get_image_nums_and_video_nums(
self, self,
input_ids: Optional[torch.LongTensor], input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Get the number of images and videos for each sample to calculate the separation length of the sample tensor. Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@ -1544,6 +1524,26 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
""" """
if inputs_embeds is not None:
is_image = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
is_video_start = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
is_video_end = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
else:
is_image = input_ids == self.config.image_start_token_id is_image = input_ids == self.config.image_start_token_id
is_video_start = input_ids == self.config.video_start_token_id is_video_start = input_ids == self.config.video_start_token_id
is_video_end = input_ids == self.config.video_end_token_id is_video_end = input_ids == self.config.video_end_token_id
View File
@ -648,24 +648,27 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features: if n_image_tokens != n_image_features:
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
) )
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
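A small sketch of the masked_scatter step above with toy shapes: the flattened image feature rows are written, in order, into the embedding positions selected by the expanded placeholder mask.

import torch

hidden = 4
inputs_embeds = torch.zeros(1, 5, hidden)
special_image_mask = torch.tensor([[False, True, True, False, False]])
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)

# two image feature rows fill the two masked token positions, in order
image_features = torch.arange(2 * hidden, dtype=inputs_embeds.dtype).reshape(2, hidden)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1])  # tensor([0., 1., 2., 3.]) -> first image feature row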
View File
@ -339,24 +339,27 @@ class GotOcr2Model(LlavaModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features: if n_image_tokens != n_image_features:
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
) )
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
View File
@ -148,7 +148,7 @@ class GraniteAttention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
View File
@ -224,7 +224,7 @@ class HeliumAttention(nn.Module):
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
View File
@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_hubert import HubertConfig from .configuration_hubert import HubertConfig
@ -300,6 +301,7 @@ class HubertAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -307,7 +309,7 @@ class HubertAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None, past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs # TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
@ -328,42 +330,9 @@ class HubertAttention(nn.Module):
# get query proj # get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj current_states = key_value_states if is_cross_attention else hidden_states
# `past_key_value[0].shape[2] == key_value_states.shape[1]` key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# is checking that the `sequence_length` of the `past_key_value` is the same as value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
@ -385,7 +354,7 @@ class HubertAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, None
class HubertFeedForward(nn.Module): class HubertFeedForward(nn.Module):
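
The `HubertAttention` hunk above drops the legacy tuple cache and always projects keys/values from `current_states`, returning `None` in the old cache slot. A stripped-down sketch of the resulting control flow (a toy module with invented sizes, not the library class):

```python
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, embed_dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states, key_value_states=None):
        bsz, tgt_len, _ = hidden_states.shape
        # cross-attention reads k/v from the encoder states, self-attention from hidden_states
        current_states = key_value_states if key_value_states is not None else hidden_states
        shape = (bsz, -1, self.num_heads, self.head_dim)
        q = self.q_proj(hidden_states).view(*shape).transpose(1, 2)
        k = self.k_proj(current_states).view(*shape).transpose(1, 2)
        v = self.v_proj(current_states).view(*shape).transpose(1, 2)
        attn_weights = torch.softmax(q @ k.transpose(-1, -2) / self.head_dim**0.5, dim=-1)
        attn_output = (attn_weights @ v).transpose(1, 2).reshape(bsz, tgt_len, -1)
        # like the refactored module, return None where the tuple cache used to be
        return self.out_proj(attn_output), attn_weights, None

out, weights, _ = ToyAttention()(torch.randn(2, 5, 64))
```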

View File

@ -28,6 +28,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
can_return_tuple,
logging, logging,
) )
from .configuration_idefics import IdeficsVisionConfig from .configuration_idefics import IdeficsVisionConfig
@ -351,6 +352,7 @@ class IdeficsVisionEncoder(nn.Module):
self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@can_return_tuple
def forward( def forward(
self, self,
inputs_embeds, inputs_embeds,
@ -417,8 +419,6 @@ class IdeficsVisionEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
) )
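
The encoder hunk above replaces the manual `if not return_dict: return tuple(...)` tail with a `@can_return_tuple` decorator. A rough sketch of the idea only (not the transformers implementation), using a toy dataclass output:

```python
import functools
from dataclasses import astuple, dataclass

def can_return_tuple_sketch(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        output = forward(self, *args, **kwargs)          # forward always builds a structured output
        return astuple(output) if return_dict is False else output
    return wrapper

@dataclass
class ToyOutput:
    last_hidden_state: float
    hidden_states: tuple = ()

class ToyEncoder:
    @can_return_tuple_sketch
    def forward(self, x):
        return ToyOutput(last_hidden_state=x * 2.0)

print(ToyEncoder().forward(1.0))                     # ToyOutput(last_hidden_state=2.0, hidden_states=())
print(ToyEncoder().forward(1.0, return_dict=False))  # (2.0, ())
```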

View File

@ -933,10 +933,18 @@ class Idefics2Model(Idefics2PreTrainedModel):
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
""" """
special_image_token_mask = input_ids == self.image_token_id if input_ids is None:
new_inputs_embeds = inputs_embeds.clone() special_image_mask = inputs_embeds == self.get_input_embeddings()(
new_inputs_embeds[special_image_token_mask] = image_hidden_states.to(new_inputs_embeds.device) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
return new_inputs_embeds )
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
return inputs_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None): def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
""" """
@ -1041,25 +1049,8 @@ class Idefics2Model(Idefics2PreTrainedModel):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
past_seen_tokens = 0 if use_cache and not isinstance(past_key_values, Cache):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache:
if not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache() past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)
past_seen_tokens = past_key_values.get_seq_length()
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.text_model.get_input_embeddings()(input_ids) inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
@ -1072,7 +1063,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
elif image_hidden_states is not None: elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None: if image_hidden_states is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images # When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist # that simply don't exist
inputs_embeds = self.inputs_merger( inputs_embeds = self.inputs_merger(
@ -1094,9 +1085,6 @@ class Idefics2Model(Idefics2PreTrainedModel):
**kwargs, **kwargs,
) )
if return_legacy_cache and use_cache:
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
return Idefics2BaseModelOutputWithPast( return Idefics2BaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state, last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values, past_key_values=outputs.past_key_values,
@ -1304,37 +1292,11 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
**kwargs, **kwargs,
) )
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step if image_hidden_states is not None or cache_position[0] != 0:
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs["input_ids"] = input_ids
if image_hidden_states is not None:
model_inputs["pixel_values"] = None model_inputs["pixel_values"] = None
model_inputs["pixel_attention_mask"] = None model_inputs["pixel_attention_mask"] = None
return model_inputs return model_inputs
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
return model_kwargs
@staticmethod
# Copied from transformers.models.opt.modeling_opt.OPTForCausalLM._reorder_cache
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
__all__ = ["Idefics2ForConditionalGeneration", "Idefics2PreTrainedModel", "Idefics2Model"] __all__ = ["Idefics2ForConditionalGeneration", "Idefics2PreTrainedModel", "Idefics2Model"]
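
When `input_ids` is `None`, the new `inputs_merger` branch above recovers the image-placeholder positions by comparing rows of `inputs_embeds` with the embedding of the image token. A toy demonstration of why that works, with an invented vocabulary and `3` as the placeholder id (it assumes no other embedding row happens to coincide with it):

```python
import torch
import torch.nn as nn

embed = nn.Embedding(10, 4)                 # toy embedding table
image_token_id = 3
input_ids = torch.tensor([[1, 3, 3, 2]])
inputs_embeds = embed(input_ids)

# rows equal to the image-token embedding (in every dimension) are the placeholders
image_tok_embed = embed(torch.tensor(image_token_id, dtype=torch.long))
special_image_mask = (inputs_embeds == image_tok_embed).all(-1)

assert torch.equal(special_image_mask, input_ids == image_token_id)
```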

View File

@ -663,15 +663,18 @@ class Idefics3Model(Idefics3PreTrainedModel):
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
""" """
special_image_token_mask = input_ids == self.image_token_id if input_ids is None:
# Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. special_image_mask = inputs_embeds == self.get_input_embeddings()(
new_inputs_embeds = inputs_embeds.clone() torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
# Flatten `image_hidden_states` if not flat yet )
image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1]) special_image_mask = special_image_mask.all(-1)
# cast to the dtype of the input_embeds to support quantized models else:
special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype) image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
new_inputs_embeds[special_image_token_mask] = image_hidden_states inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
return new_inputs_embeds return inputs_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None): def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
""" """
@ -773,11 +776,8 @@ class Idefics3Model(Idefics3PreTrainedModel):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
past_seen_tokens = 0 if use_cache and past_key_values is None:
if use_cache:
if past_key_values is None:
past_key_values = DynamicCache() past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length()
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device) inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)
@ -790,7 +790,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
elif image_hidden_states is not None: elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
if past_seen_tokens == 0 and input_ids is not None and image_hidden_states is not None: if image_hidden_states is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images # When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist # that simply don't exist
inputs_embeds = self.inputs_merger( inputs_embeds = self.inputs_merger(
@ -1042,28 +1042,11 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
**kwargs, **kwargs,
) )
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step if image_hidden_states is not None or cache_position[0] != 0:
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs["input_ids"] = input_ids
if image_hidden_states is not None:
model_inputs["pixel_values"] = None model_inputs["pixel_values"] = None
model_inputs["pixel_attention_mask"] = None model_inputs["pixel_attention_mask"] = None
return model_inputs return model_inputs
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration._update_model_kwargs_for_generation
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
return model_kwargs
__all__ = ["Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", "Idefics3VisionTransformer"] __all__ = ["Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", "Idefics3VisionTransformer"]
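
`prepare_inputs_for_generation` above now drops the pixel inputs whenever `image_hidden_states` is already available or the model is past the prefill step (`cache_position[0] != 0`). A small sketch of that gating with invented inputs:

```python
import torch

def prune_pixel_inputs(model_inputs, image_hidden_states, cache_position):
    # keep pixel inputs only on the first (prefill) step and only if no cached image states exist
    if image_hidden_states is not None or cache_position[0] != 0:
        model_inputs["pixel_values"] = None
        model_inputs["pixel_attention_mask"] = None
    return model_inputs

prefill = prune_pixel_inputs({"pixel_values": torch.randn(1, 3, 8, 8)}, None, torch.arange(5))
decode = prune_pixel_inputs({"pixel_values": torch.randn(1, 3, 8, 8)}, None, torch.tensor([5]))
print(prefill["pixel_values"] is None, decode["pixel_values"] is None)  # False True
```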

View File

@ -1255,6 +1255,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
@ -1328,12 +1329,20 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
# step 3: use the language model, conditioned on the query outputs and the prompt # step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output) language_model_inputs = self.language_projection(query_output)
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
special_image_mask = input_ids == self.config.image_token_id
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
else:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds[special_image_mask] = language_model_inputs.flatten() language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
if self.config.use_decoder_only_language_model: if self.config.use_decoder_only_language_model:
outputs = self.language_model( outputs = self.language_model(
@ -1513,6 +1522,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
@ -1604,15 +1614,26 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
) )
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
# if the model already has "image_token_id" then the input is expanded to account for image embeds # if the model already has "image_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating # otherwise we expand manually by concatenating
if getattr(self.config, "image_token_id", None) is not None: if getattr(self.config, "image_token_id", None) is not None:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) if input_ids is None:
inputs_embeds[special_image_mask] = language_model_inputs.flatten() special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else: else:
logger.warning_once( logger.warning_once(
"Expanding inputs for image tokens in InstructBLIP should be done in processing. " "Expanding inputs for image tokens in InstructBLIP should be done in processing. "
@ -1673,6 +1694,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
qformer_attention_mask: Optional[torch.LongTensor] = None, qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
interpolate_pos_encoding: bool = False, interpolate_pos_encoding: bool = False,
**generate_kwargs, **generate_kwargs,
) -> torch.LongTensor: ) -> torch.LongTensor:
@ -1690,6 +1712,8 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
The sequence used as a prompt for the generation. The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices. Mask to avoid performing attention on padding token indices.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the positional encoding of the image embeddings. Whether to interpolate the positional encoding of the image embeddings.
@ -1712,23 +1736,32 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
) )
if inputs_embeds is None:
if input_ids is None: if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id] start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_id", None) is not None: if getattr(self.config, "image_token_id", None) is not None:
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
input_ids = input_ids.repeat(batch_size, 1) input_ids = input_ids.repeat(batch_size, 1)
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
inputs_embeds = self.get_input_embeddings()(input_ids)
# if the model already has "image_token_id" then the input is expanded to account for image embeds # if the model already has "image_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating # otherwise we expand manually by concatenating
if getattr(self.config, "image_token_id", None) is not None: if getattr(self.config, "image_token_id", None) is not None:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) if input_ids is None:
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else: else:
logger.warning_once( logger.warning_once(
"Expanding inputs for image tokens in InstructBLIP should be done in processing. " "Expanding inputs for image tokens in InstructBLIP should be done in processing. "

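When neither `input_ids` nor `inputs_embeds` is supplied, the `generate` path above falls back to a prompt of image placeholder tokens followed by BOS, repeated across the batch. A toy version of that construction with made-up ids and sizes:

```python
import torch

bos_token_id, image_token_id = 1, 32000      # made-up ids
num_query_tokens, batch_size = 32, 2

start_tokens = [image_token_id] * num_query_tokens + [bos_token_id]
input_ids = torch.tensor([start_tokens], dtype=torch.long).repeat(batch_size, 1)
attention_mask = torch.ones_like(input_ids)
print(input_ids.shape)  # torch.Size([2, 33])
```
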
View File

@ -1251,6 +1251,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
@ -1334,12 +1335,20 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
special_image_mask = input_ids == self.config.video_token_id
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
else:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds[special_image_mask] = language_model_inputs.flatten() language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
if self.config.use_decoder_only_language_model: if self.config.use_decoder_only_language_model:
outputs = self.language_model( outputs = self.language_model(
@ -1485,6 +1494,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
@ -1599,15 +1609,26 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
) )
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
# if the model already has "video_token_id" then the input is expanded to account for image embeds # if the model already has "video_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating # otherwise we expand manually by concatenating
if getattr(self.config, "video_token_id", None) is not None: if getattr(self.config, "video_token_id", None) is not None:
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) if input_ids is None:
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.video_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else: else:
logger.warning_once( logger.warning_once(
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
@ -1668,6 +1689,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
qformer_attention_mask: Optional[torch.LongTensor] = None, qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
interpolate_pos_encoding: bool = False, interpolate_pos_encoding: bool = False,
**generate_kwargs, **generate_kwargs,
) -> torch.LongTensor: ) -> torch.LongTensor:
@ -1685,6 +1707,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
The sequence used as a prompt for the generation. The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices. Mask to avoid performing attention on padding token indices.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the positional encoding of the image embeddings. Whether to interpolate the positional encoding of the image embeddings.
@ -1708,23 +1732,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
) )
if inputs_embeds is None:
if input_ids is None: if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id] start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_id", None) is not None: if getattr(self.config, "video_token_id", None) is not None:
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
input_ids = input_ids.repeat(batch_size, 1) input_ids = input_ids.repeat(batch_size, 1)
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
inputs_embeds = self.get_input_embeddings()(input_ids)
# if the model already has "video_token_id" then the input is expanded to account for image embeds # if the model already has "video_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating # otherwise we expand manually by concatenating
if getattr(self.config, "video_token_id", None) is not None: if getattr(self.config, "video_token_id", None) is not None:
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) if input_ids is None:
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.video_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else: else:
logger.warning_once( logger.warning_once(
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "

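For the video model, the Q-Former output above is computed per frame and then folded back into one sequence per sample (`num_query_tokens * frames` query positions). A shape-only sketch with invented dimensions:

```python
import torch

batch_size, frames, num_query_tokens, hidden = 2, 4, 8, 16
per_frame_queries = torch.randn(batch_size * frames, num_query_tokens, hidden)

# unbatch the frames: each sample now carries num_query_tokens * frames query embeddings
language_model_inputs = per_frame_queries.reshape(batch_size, num_query_tokens * frames, -1)
print(language_model_inputs.shape)  # torch.Size([2, 32, 16])
```
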
View File

@ -202,6 +202,7 @@ class InstructBlipVideoModel(InstructBlipModel):
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
@ -255,12 +256,20 @@ class InstructBlipVideoModel(InstructBlipModel):
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
special_image_mask = input_ids == self.config.video_token_id
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
else:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds[special_image_mask] = language_model_inputs.flatten() language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
if self.config.use_decoder_only_language_model: if self.config.use_decoder_only_language_model:
outputs = self.language_model( outputs = self.language_model(
@ -372,6 +381,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
@ -451,15 +461,26 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
) )
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
# if the model already has "video_token_id" then the input is expanded to account for image embeds # if the model already has "video_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating # otherwise we expand manually by concatenating
if getattr(self.config, "video_token_id", None) is not None: if getattr(self.config, "video_token_id", None) is not None:
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) if input_ids is None:
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.video_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else: else:
logger.warning_once( logger.warning_once(
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
@ -520,6 +541,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
qformer_attention_mask: Optional[torch.LongTensor] = None, qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
interpolate_pos_encoding: bool = False, interpolate_pos_encoding: bool = False,
**generate_kwargs, **generate_kwargs,
) -> torch.LongTensor: ) -> torch.LongTensor:
@ -537,6 +559,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
The sequence used as a prompt for the generation. The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices. Mask to avoid performing attention on padding token indices.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the positional encoding of the image embeddings. Whether to interpolate the positional encoding of the image embeddings.
@ -560,23 +584,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
) )
if inputs_embeds is None:
if input_ids is None: if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id] start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_id", None) is not None: if getattr(self.config, "video_token_id", None) is not None:
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
input_ids = input_ids.repeat(batch_size, 1) input_ids = input_ids.repeat(batch_size, 1)
inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones_like(input_ids) attention_mask = torch.ones_like(input_ids)
inputs_embeds = self.get_input_embeddings()(input_ids)
# if the model already has "video_token_id" then the input is expanded to account for image embeds # if the model already has "video_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating # otherwise we expand manually by concatenating
if getattr(self.config, "video_token_id", None) is not None: if getattr(self.config, "video_token_id", None) is not None:
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) if input_ids is None:
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.video_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else: else:
logger.warning_once( logger.warning_once(
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "

View File

@ -710,14 +710,14 @@ class InternVLModel(InternVLPreTrainedModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()( special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
) )
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] special_image_mask = special_image_mask.all(-1)
else: else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
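
The InternVL hunk above counts placeholder tokens on the token-level mask and compares that count against the number of projected image features before scattering. A toy version of the consistency check, with made-up numbers chosen to match:

```python
import torch

image_token_id = 9                                  # made-up id
input_ids = torch.tensor([[4, 9, 9, 9, 5]])
image_features = torch.randn(1, 3, 16)              # (num_images, tokens_per_image, hidden)

n_image_tokens = (input_ids == image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:              # 3 == 3 here, so no error is raised
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
    )
```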

View File

@ -641,14 +641,14 @@ class InternVLModel(LlavaModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()( special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
) )
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] special_image_mask = special_image_mask.all(-1)
else: else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@ -1102,23 +1102,21 @@ class JanusModel(JanusPreTrainedModel):
) )
use_cache = False use_cache = False
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values) if input_ids is None:
image_attention_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_attention_mask = image_attention_mask.all(-1)
else:
image_attention_mask = input_ids == self.config.image_token_id image_attention_mask = input_ids == self.config.image_token_id
embed_dim = inputs_embeds.shape[-1] image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_embeds.reshape(-1, embed_dim) image_embeds = self.get_image_features(pixel_values)
image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim) image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)

View File

@ -955,23 +955,21 @@ class JanusModel(JanusPreTrainedModel):
) )
use_cache = False use_cache = False
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values) if input_ids is None:
image_attention_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_attention_mask = image_attention_mask.all(-1)
else:
image_attention_mask = input_ids == self.config.image_token_id image_attention_mask = input_ids == self.config.image_token_id
embed_dim = inputs_embeds.shape[-1] image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_embeds.reshape(-1, embed_dim) image_embeds = self.get_image_features(pixel_values)
image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim) image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)

View File

@ -451,6 +451,7 @@ class Kosmos2VisionEncoder(nn.Module):
self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@can_return_tuple
def forward( def forward(
self, self,
inputs_embeds, inputs_embeds,
@ -517,8 +518,6 @@ class Kosmos2VisionEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
) )
@ -1468,25 +1467,19 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
image_embeds_position_mask=None, image_embeds_position_mask=None,
past_key_values=None, past_key_values=None,
attention_mask=None, attention_mask=None,
inputs_embeds=None,
use_cache=None, use_cache=None,
cache_position=None, cache_position=None,
**model_kwargs, **model_kwargs,
): ):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
# Kosmos2 has offset for position ids, so we need to create them correctly
position_ids = create_position_ids_from_input_ids(
input_ids,
padding_idx=self.config.pad_token_id,
past_key_values_length=0,
)
if past_key_values is not None: if past_key_values is not None:
image_embeds = None image_embeds = None
image_embeds_position_mask = None image_embeds_position_mask = None
# appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation) # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation)
elif image_embeds_position_mask is not None: elif image_embeds_position_mask is not None:
batch_size, seq_len = input_ids.size() batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size()
mask_len = image_embeds_position_mask.size()[-1] mask_len = image_embeds_position_mask.size()[-1]
image_embeds_position_mask = torch.cat( image_embeds_position_mask = torch.cat(
( (
@ -1502,11 +1495,13 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
attention_mask=attention_mask, attention_mask=attention_mask,
image_embeds=image_embeds, image_embeds=image_embeds,
image_embeds_position_mask=image_embeds_position_mask, image_embeds_position_mask=image_embeds_position_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
position_ids=position_ids,
cache_position=cache_position, cache_position=cache_position,
**model_kwargs, **model_kwargs,
) )
# Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer
model_inputs.pop("position_ids", None)
return model_inputs return model_inputs
@ -1876,6 +1871,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
input_ids: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
): ):
# in order to allow `inputs` argument (as in `GenerationMixin`) # in order to allow `inputs` argument (as in `GenerationMixin`)
@ -1901,6 +1897,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
attention_mask=attention_mask, attention_mask=attention_mask,
image_embeds=image_embeds, image_embeds=image_embeds,
image_embeds_position_mask=image_embeds_position_mask, image_embeds_position_mask=image_embeds_position_mask,
inputs_embeds=inputs_embeds,
**kwargs, **kwargs,
) )
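
During generation, the Kosmos-2 hunk above keeps `image_embeds_position_mask` aligned with the growing sequence by padding it with `False` for every newly generated position. A small sketch with invented sizes:

```python
import torch

batch_size, seq_len = 2, 7                                    # current decoded length
image_embeds_position_mask = torch.tensor([[1, 1, 1, 0],
                                            [1, 1, 0, 0]], dtype=torch.bool)  # prompt-time mask
mask_len = image_embeds_position_mask.size(-1)

# newly generated tokens can never be image positions, so append False for them
image_embeds_position_mask = torch.cat(
    (image_embeds_position_mask, torch.zeros(batch_size, seq_len - mask_len, dtype=torch.bool)),
    dim=1,
)
print(image_embeds_position_mask.shape)  # torch.Size([2, 7])
```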

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
"""LayoutLM model configuration""" """LayoutLM model configuration"""
import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional from typing import Any, Optional
@ -130,10 +131,22 @@ class LayoutLMConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type self._position_embedding_type = position_embedding_type
self.use_cache = use_cache self.use_cache = use_cache
self.max_2d_position_embeddings = max_2d_position_embeddings self.max_2d_position_embeddings = max_2d_position_embeddings
@property
def position_embedding_type(self):
warnings.warn(
"The `position_embedding_type` attribute is deprecated and will be removed in v4.55.",
FutureWarning,
)
return self._position_embedding_type
@position_embedding_type.setter
def position_embedding_type(self, value):
self._position_embedding_type = value
class LayoutLMOnnxConfig(OnnxConfig): class LayoutLMOnnxConfig(OnnxConfig):
def __init__( def __init__(
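
The LayoutLM config hunk above deprecates direct reads of `position_embedding_type` by backing it with a private field and routing access through a warning property. A minimal standalone version of the same pattern (a toy class, not the library config):

```python
import warnings

class ToyConfig:
    def __init__(self, position_embedding_type: str = "absolute"):
        self._position_embedding_type = position_embedding_type

    @property
    def position_embedding_type(self):
        warnings.warn(
            "The `position_embedding_type` attribute is deprecated and will be removed in a future version.",
            FutureWarning,
        )
        return self._position_embedding_type

    @position_embedding_type.setter
    def position_embedding_type(self, value):
        self._position_embedding_type = value

cfg = ToyConfig()
cfg.position_embedding_type = "relative_key"   # setting does not warn
_ = cfg.position_embedding_type                # reading emits a FutureWarning
```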

View File

@ -14,8 +14,7 @@
# limitations under the License. # limitations under the License.
"""PyTorch LayoutLM model.""" """PyTorch LayoutLM model."""
import math from typing import Callable, Optional, Union
from typing import Optional, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@ -25,16 +24,17 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
- BaseModelOutputWithPastAndCrossAttentions,
- BaseModelOutputWithPoolingAndCrossAttentions,
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
- from ...modeling_utils import PreTrainedModel
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
- from ...utils import auto_docstring, logging
+ from ...utils import auto_docstring, can_return_tuple, logging
+ from ...utils.deprecation import deprecate_kwarg
from .configuration_layoutlm import LayoutLMConfig
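`deprecate_kwarg`, imported here, is used below to flag arguments scheduled for removal. A hypothetical sketch of that kind of decorator (not the transformers implementation):

import functools
import warnings


def deprecate_kwarg_sketch(name: str, version: str):
    # Warns when a soon-to-be-removed keyword argument is still being passed.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if kwargs.get(name) is not None:
                warnings.warn(
                    f"`{name}` is deprecated and will be removed in v{version}.",
                    FutureWarning,
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecate_kwarg_sketch("past_key_value", version="4.54.0")
def forward(hidden_states, past_key_value=None):
    return hidden_states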
@ -120,9 +120,37 @@ class LayoutLMEmbeddings(nn.Module):
return embeddings
- # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM
+ # Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
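A standalone shape check for the eager path above. It assumes the `eager_attention_forward` defined in this hunk is in scope; the sizes are arbitrary.

import torch
from torch import nn

batch, heads, seq_len, head_dim = 2, 12, 16, 64
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

attn_output, attn_weights = eager_attention_forward(
    module=nn.Module(),  # only `.training` is read, by the dropout call
    query=query,
    key=key,
    value=value,
    attention_mask=None,
    scaling=head_dim**-0.5,
)
print(attn_output.shape)   # torch.Size([2, 16, 12, 64]): (batch, seq, heads, head_dim) after the transpose
print(attn_weights.shape)  # torch.Size([2, 12, 16, 16])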
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->LayoutLM
class LayoutLMSelfAttention(nn.Module):
- def __init__(self, config, position_embedding_type=None):
+ def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@ -130,6 +158,7 @@ class LayoutLMSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})" f"heads ({config.num_attention_heads})"
) )
self.config = config
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -139,20 +168,12 @@ class LayoutLMSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.attention_dropout = config.attention_probs_dropout_prob
+ self.scaling = self.attention_head_size**-0.5
- self.position_embedding_type = position_embedding_type or getattr(
- config, "position_embedding_type", "absolute"
- )
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -162,96 +183,33 @@ class LayoutLMSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
- mixed_query_layer = self.query(hidden_states)
+ input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
+ query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- # If this is instantiated as a cross-attention module, the keys
- # and values come from an encoder; the attention mask needs to be
- # such that the encoder's padding tokens are not attended to.
- is_cross_attention = encoder_hidden_states is not None
- if is_cross_attention and past_key_value is not None:
- # reuse k,v, cross_attentions
- key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ head_mask=head_mask,
+ **kwargs,
+ )
- query_layer = self.transpose_for_scores(mixed_query_layer)
- use_cache = past_key_value is not None
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_layer, value_layer)
- # Take the dot product between "query" and "key" to get the raw attention scores.
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
- if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
- query_length, key_length = query_layer.shape[2], key_layer.shape[2]
- if use_cache:
- position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
- -1, 1
- )
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
- positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
- positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in LayoutLMModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
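The rewritten forward resolves the attention implementation at runtime: `eager_attention_forward` is the default, and any other value of `config._attn_implementation` is looked up in `ALL_ATTENTION_FUNCTIONS`. A stand-in sketch of that dispatch with a toy registry (not the transformers one), assuming the `eager_attention_forward` defined above:

from typing import Callable

import torch.nn.functional as F


def sdpa_attention(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Same calling convention as eager_attention_forward, but backed by torch's fused kernel.
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None


ATTENTION_REGISTRY: dict[str, Callable] = {"sdpa": sdpa_attention}


def pick_attention(attn_implementation: str) -> Callable:
    if attn_implementation == "eager":
        return eager_attention_forward  # the hand-written path defined earlier in this file
    return ATTENTION_REGISTRY[attn_implementation]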
@ -270,18 +228,11 @@ class LayoutLMSelfOutput(nn.Module):
return hidden_states
- LAYOUTLM_SELF_ATTENTION_CLASSES = {
- "eager": LayoutLMSelfAttention,
- }
- # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM
+ # Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->LayoutLM
class LayoutLMAttention(nn.Module):
- def __init__(self, config, position_embedding_type=None):
+ def __init__(self, config):
super().__init__()
- self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
- config, position_embedding_type=position_embedding_type
- )
+ self.self = LayoutLMSelfAttention(config)
self.output = LayoutLMSelfOutput(config)
self.pruned_heads = set()
@ -303,6 +254,9 @@ class LayoutLMAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -312,15 +266,14 @@ class LayoutLMAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
- attention_mask,
- head_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- past_key_value,
- output_attentions,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ **kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -358,22 +311,19 @@ class LayoutLMOutput(nn.Module):
return hidden_states
- # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM
+ # Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->LayoutLM
class LayoutLMLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = LayoutLMAttention(config)
- self.is_decoder = config.is_decoder
- self.add_cross_attention = config.add_cross_attention
- if self.add_cross_attention:
- if not self.is_decoder:
- raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
- self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute")
self.intermediate = LayoutLMIntermediate(config)
self.output = LayoutLMOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -383,60 +333,23 @@ class LayoutLMLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
- attention_mask,
- head_mask,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
output_attentions=output_attentions,
- past_key_value=self_attn_past_key_value,
+ **kwargs,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward( layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
) )
outputs = (layer_output,) + outputs outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs return outputs
def feed_forward_chunk(self, attention_output): def feed_forward_chunk(self, attention_output):
@ -445,14 +358,19 @@ class LayoutLMLayer(GradientCheckpointingLayer):
return layer_output
- # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->LayoutLM
+ # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->LayoutLM
class LayoutLMEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
- self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)])
+ self.layer = nn.ModuleList([LayoutLMLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -465,65 +383,36 @@ class LayoutLMEncoder(nn.Module):
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
- ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ **kwargs,
+ ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
- all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
- past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
- hidden_states,
- attention_mask,
- layer_head_mask,
- encoder_hidden_states,  # as a positional argument for gradient checkpointing
- encoder_attention_mask=encoder_attention_mask,
- past_key_value=past_key_value,
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ head_mask=layer_head_mask,
output_attentions=output_attentions,
+ **kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
- if self.config.add_cross_attention:
- all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
- if not return_dict:
- return tuple(
- v
- for v in [
- hidden_states,
- next_decoder_cache,
- all_hidden_states,
- all_self_attentions,
- all_cross_attentions,
- ]
- if v is not None
- )
- return BaseModelOutputWithPastAndCrossAttentions(
+ return BaseModelOutput(
last_hidden_state=hidden_states,
- past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
- cross_attentions=all_cross_attentions,
)
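The `if not return_dict:` tuple-building branches removed here are now covered by the `@can_return_tuple` decorator. A hypothetical sketch of what such a decorator does (not the transformers implementation):

import functools


def can_return_tuple_sketch(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        output = forward(self, *args, **kwargs)
        use_dict = self.config.use_return_dict if return_dict is None else return_dict
        # ModelOutput instances expose to_tuple(); return the dataclass only when return_dict is True.
        return output if use_dict else output.to_tuple()

    return wrapper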
@ -648,6 +537,9 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -663,7 +555,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
+ ) -> Union[tuple, BaseModelOutputWithPooling]:
r"""
bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
Bounding boxes of each input sequence tokens. Selected in the range `[0,
@ -756,20 +648,16 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
- if not return_dict:
- return (sequence_output, pooled_output) + encoder_outputs[1:]
- return BaseModelOutputWithPoolingAndCrossAttentions(
+ return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
- cross_attentions=encoder_outputs.cross_attentions,
)
@ -796,6 +684,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
self.cls.predictions.decoder = new_embeddings self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias self.cls.predictions.bias = new_embeddings.bias
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -871,11 +762,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
)
sequence_output = outputs[0] sequence_output = outputs[0]
@ -889,10 +778,6 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
labels.view(-1), labels.view(-1),
) )
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput( return MaskedLMOutput(
loss=masked_lm_loss, loss=masked_lm_loss,
logits=prediction_scores, logits=prediction_scores,
@ -921,6 +806,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
def get_input_embeddings(self): def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings return self.layoutlm.embeddings.word_embeddings
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -996,7 +882,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
)
pooled_output = outputs[1]
@ -1026,9 +912,6 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
elif self.config.problem_type == "multi_label_classification": elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss() loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels) loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, loss=loss,
@ -1059,6 +942,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
def get_input_embeddings(self): def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings return self.layoutlm.embeddings.word_embeddings
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1132,7 +1016,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
)
sequence_output = outputs[0]
@ -1145,10 +1029,6 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, loss=loss,
logits=logits, logits=logits,
@ -1176,6 +1056,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
def get_input_embeddings(self): def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings return self.layoutlm.embeddings.word_embeddings
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1253,7 +1134,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
- return_dict=return_dict,
+ return_dict=True,
)
sequence_output = outputs[0]
@ -1280,10 +1161,6 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
end_loss = loss_fct(end_logits, end_positions) end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2 total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput( return QuestionAnsweringModelOutput(
loss=total_loss, loss=total_loss,
start_logits=start_logits, start_logits=start_logits,
View File
@ -224,7 +224,7 @@ class LlamaAttention(nn.Module):
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
View File
@ -1358,27 +1358,28 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
- original_inputs_embeds_shape = inputs_embeds.shape
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- final_mask = special_image_mask.to(inputs_embeds.device)
- inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
- final_mask_1d = final_mask[..., 0].reshape(-1)
- num_tokens_to_fill = final_mask_1d.sum()
- if num_tokens_to_fill != projected_vision_flat.size(0):
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ n_image_tokens = (special_image_mask).sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ if n_image_tokens != projected_vision_flat.size(0):
raise ValueError(
- f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
+ f"Mismatch: final_mask wants {n_image_tokens} embeddings, "
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
)
+ projected_vision_flat = projected_vision_flat.to(inputs_embeds.device, inputs_embeds.dtype)
- expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
- inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)
- inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat)
outputs = self.language_model(
attention_mask=attention_mask,
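The replacement code relies on `masked_scatter` to drop the projected vision features into the placeholder positions. A toy, self-contained illustration of that mechanic (sizes invented for the example, no real model involved):

import torch

hidden_size, image_token_id = 4, 99
input_ids = torch.tensor([[1, 99, 99, 2]])            # two image placeholder tokens
inputs_embeds = torch.zeros(1, 4, hidden_size)
image_features = torch.arange(2 * hidden_size, dtype=torch.float).view(2, hidden_size)

mask = (input_ids == image_token_id)                  # (batch, seq)
mask = mask.unsqueeze(-1).expand_as(inputs_embeds)    # (batch, seq, hidden)
inputs_embeds = inputs_embeds.masked_scatter(mask, image_features)
print(inputs_embeds[0, 1])  # tensor([0., 1., 2., 3.])
print(inputs_embeds[0, 2])  # tensor([4., 5., 6., 7.])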
View File
@ -284,14 +284,14 @@ class LlavaModel(LlavaPreTrainedModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
- n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ special_image_mask = special_image_mask.all(-1)
else:
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- n_image_tokens = (input_ids == self.config.image_token_id).sum()
+ special_image_mask = input_ids == self.config.image_token_id
+ n_image_tokens = (special_image_mask).sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
View File
@ -468,11 +468,6 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -485,10 +480,18 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
) )
image_features = torch.cat(image_features, dim=0) image_features = torch.cat(image_features, dim=0)
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ n_image_tokens = (special_image_mask).sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
View File
@ -519,12 +519,6 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -537,10 +531,18 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
) )
image_features = torch.cat(image_features, dim=0) image_features = torch.cat(image_features, dim=0)
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ n_image_tokens = (special_image_mask).sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@ -559,10 +561,18 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
video_features = torch.cat(video_features, dim=0) video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
- special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.video_token_id
+ n_video_tokens = (special_image_mask).sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
View File
@ -440,12 +440,6 @@ class LlavaNextVideoModel(LlavaNextModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -458,10 +452,18 @@ class LlavaNextVideoModel(LlavaNextModel):
) )
image_features = torch.cat(image_features, dim=0) image_features = torch.cat(image_features, dim=0)
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ n_image_tokens = (special_image_mask).sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@ -480,10 +482,18 @@ class LlavaNextVideoModel(LlavaNextModel):
video_features = torch.cat(video_features, dim=0) video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
- special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.video_token_id
+ n_video_tokens = (special_image_mask).sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
View File
@ -551,12 +551,6 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -571,10 +565,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
) )
image_features = torch.cat(image_features, dim=0) image_features = torch.cat(image_features, dim=0)
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ n_image_tokens = (special_image_mask).sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@ -595,10 +597,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
video_features = torch.cat((video_features, image_newline), dim=1) video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1) video_features = video_features.flatten(0, 1)
- special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
- special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if input_ids is None:
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_video_mask = input_ids == self.config.video_token_id
+ n_video_tokens = (special_video_mask).sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
View File
@ -535,12 +535,6 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -555,10 +549,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
) )
image_features = torch.cat(image_features, dim=0) image_features = torch.cat(image_features, dim=0)
- special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+ n_image_tokens = (special_image_mask).sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@ -579,10 +581,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
video_features = torch.cat((video_features, image_newline), dim=1) video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1) video_features = video_features.flatten(0, 1)
- special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
- special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if input_ids is None:
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_video_mask = special_video_mask.all(-1)
+ else:
+ special_video_mask = input_ids == self.config.video_token_id
+ n_video_tokens = (special_video_mask).sum()
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
View File
@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
"""MarkupLM model configuration""" """MarkupLM model configuration"""
import warnings
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@ -141,7 +143,7 @@ class MarkupLMConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
- self.position_embedding_type = position_embedding_type
+ self._position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
# additional properties
@ -152,5 +154,17 @@ class MarkupLMConfig(PretrainedConfig):
self.subs_pad_id = subs_pad_id self.subs_pad_id = subs_pad_id
self.xpath_unit_hidden_size = xpath_unit_hidden_size self.xpath_unit_hidden_size = xpath_unit_hidden_size
@property
def position_embedding_type(self):
warnings.warn(
"The `position_embedding_type` attribute is deprecated and will be removed in v4.55.",
FutureWarning,
)
return self._position_embedding_type
@position_embedding_type.setter
def position_embedding_type(self, value):
self._position_embedding_type = value
__all__ = ["MarkupLMConfig"] __all__ = ["MarkupLMConfig"]
View File
@ -14,9 +14,8 @@
# limitations under the License.
"""PyTorch MarkupLM model."""
- import math
import os
- from typing import Optional, Union
+ from typing import Callable, Optional, Union
import torch
import torch.utils.checkpoint
@ -26,20 +25,22 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
- BaseModelOutputWithPastAndCrossAttentions,
- BaseModelOutputWithPoolingAndCrossAttentions,
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
+ ALL_ATTENTION_FUNCTIONS,
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
- from ...utils import auto_docstring, logging
+ from ...utils import auto_docstring, can_return_tuple, logging
+ from ...utils.deprecation import deprecate_kwarg
from .configuration_markuplm import MarkupLMConfig
@ -326,9 +327,37 @@ class MarkupLMOnlyMLMHead(nn.Module):
return prediction_scores
- # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MarkupLM
+ # Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->MarkupLM
class MarkupLMSelfAttention(nn.Module):
- def __init__(self, config, position_embedding_type=None):
+ def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@ -336,6 +365,7 @@ class MarkupLMSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})" f"heads ({config.num_attention_heads})"
) )
self.config = config
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -345,20 +375,12 @@ class MarkupLMSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.attention_dropout = config.attention_probs_dropout_prob
+ self.scaling = self.attention_head_size**-0.5
- self.position_embedding_type = position_embedding_type or getattr(
- config, "position_embedding_type", "absolute"
- )
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -368,111 +390,41 @@ class MarkupLMSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
- mixed_query_layer = self.query(hidden_states)
+ input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
# and values come from an encoder; the attention mask needs to be key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
# such that the encoder's padding tokens are not attended to. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None: attention_interface: Callable = eager_attention_forward
# reuse k,v, cross_attentions if self.config._attn_implementation != "eager":
key_layer = past_key_value[0] attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer) attn_output, attn_weights = attention_interface(
self,
use_cache = past_key_value is not None query_states,
if self.is_decoder: key_states,
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. value_states,
# Further calls to cross_attention layer can then reuse all cross-attention attention_mask,
# key/value_states (first "if" case) dropout=0.0 if not self.training else self.attention_dropout,
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of scaling=self.scaling,
# all previous decoder key/value_states. Further calls to uni-directional self-attention head_mask=head_mask,
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) **kwargs,
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
) )
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) attn_output = attn_output.reshape(*input_shape, -1).contiguous()
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in MarkupLMModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs return outputs
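The refactored forward no longer computes attention inline; it dispatches to a pluggable callable keyed by `config._attn_implementation`. A simplified stand-in for that dispatch is sketched below; the registry and function names are illustrative and are not transformers' actual `ALL_ATTENTION_FUNCTIONS` mapping.

import torch
import torch.nn.functional as F

def eager_attention(q, k, v, mask=None, scaling=1.0, **kwargs):
    scores = torch.matmul(q, k.transpose(2, 3)) * scaling
    if mask is not None:
        scores = scores + mask
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, v).transpose(1, 2).contiguous(), weights

def sdpa_attention(q, k, v, mask=None, scaling=None, **kwargs):
    # relies on SDPA's default 1/sqrt(head_dim) scaling
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).contiguous(), None

ATTENTION_IMPLEMENTATIONS = {"eager": eager_attention, "sdpa": sdpa_attention}

def run_attention(impl, q, k, v, **kwargs):
    # fall back to eager when the requested implementation is not registered
    return ATTENTION_IMPLEMENTATIONS.get(impl, eager_attention)(q, k, v, **kwargs)

q = k = v = torch.randn(1, 2, 4, 8)
out, _ = run_attention("sdpa", q, k, v)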
MARKUPLM_SELF_ATTENTION_CLASSES = { # Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->MarkupLM
"eager": MarkupLMSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM,BERT->MARKUPLM
class MarkupLMAttention(nn.Module): class MarkupLMAttention(nn.Module):
def __init__(self, config, position_embedding_type=None): def __init__(self, config):
super().__init__() super().__init__()
self.self = MARKUPLM_SELF_ATTENTION_CLASSES[config._attn_implementation]( self.self = MarkupLMSelfAttention(config)
config, position_embedding_type=position_embedding_type
)
self.output = MarkupLMSelfOutput(config) self.output = MarkupLMSelfOutput(config)
self.pruned_heads = set() self.pruned_heads = set()
@ -494,6 +446,9 @@ class MarkupLMAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -503,37 +458,33 @@ class MarkupLMAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
encoder_hidden_states, output_attentions=output_attentions,
encoder_attention_mask, **kwargs,
past_key_value,
output_attentions,
) )
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs return outputs
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->MarkupLM # Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->MarkupLM
class MarkupLMLayer(GradientCheckpointingLayer): class MarkupLMLayer(GradientCheckpointingLayer):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1 self.seq_len_dim = 1
self.attention = MarkupLMAttention(config) self.attention = MarkupLMAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = MarkupLMAttention(config, position_embedding_type="absolute")
self.intermediate = MarkupLMIntermediate(config) self.intermediate = MarkupLMIntermediate(config)
self.output = MarkupLMOutput(config) self.output = MarkupLMOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -543,60 +494,23 @@ class MarkupLMLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
past_key_value=self_attn_past_key_value, **kwargs,
) )
attention_output = self_attention_outputs[0] attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward( layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
) )
outputs = (layer_output,) + outputs outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs return outputs
def feed_forward_chunk(self, attention_output): def feed_forward_chunk(self, attention_output):
@ -605,14 +519,19 @@ class MarkupLMLayer(GradientCheckpointingLayer):
return layer_output return layer_output
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->MarkupLM # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->MarkupLM
class MarkupLMEncoder(nn.Module): class MarkupLMEncoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
self.layer = nn.ModuleList([MarkupLMLayer(config) for _ in range(config.num_hidden_layers)]) self.layer = nn.ModuleList([MarkupLMLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -625,65 +544,36 @@ class MarkupLMEncoder(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True, return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: **kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states=hidden_states,
attention_mask, attention_mask=attention_mask,
layer_head_mask, head_mask=layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
**kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],) all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict: return BaseModelOutput(
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
) )
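Stripped of cache and cross-attention handling, the encoder loop above reduces to threading hidden states through the layers and optionally collecting intermediates. A toy sketch of that control flow, with stand-in layer callables:

from dataclasses import dataclass
from typing import Optional

@dataclass
class EncoderOutputSketch:
    last_hidden_state: object
    hidden_states: Optional[tuple] = None
    attentions: Optional[tuple] = None

def run_encoder(layers, hidden_states, output_hidden_states=False, output_attentions=False):
    all_hidden = () if output_hidden_states else None
    all_attn = () if output_attentions else None
    for layer in layers:
        if output_hidden_states:
            all_hidden = all_hidden + (hidden_states,)
        hidden_states, attn = layer(hidden_states)
        if output_attentions:
            all_attn = all_attn + (attn,)
    if output_hidden_states:
        all_hidden = all_hidden + (hidden_states,)
    return EncoderOutputSketch(hidden_states, all_hidden, all_attn)

# toy usage: each "layer" adds 1 and returns no attention weights
out = run_encoder([lambda h: (h + 1, None)] * 3, 0, output_hidden_states=True)
assert out.last_hidden_state == 3 and out.hidden_states == (0, 1, 2, 3)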
@ -749,6 +639,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -763,7 +654,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]: ) -> Union[tuple, BaseModelOutputWithPooling]:
r""" r"""
xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*): xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
Tag IDs for each token in the input sequence, padded up to config.max_depth. Tag IDs for each token in the input sequence, padded up to config.max_depth.
@ -839,21 +730,16 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
head_mask=head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict: return BaseModelOutputWithPooling(
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
pooler_output=pooled_output, pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states, hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions, attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
) )
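With the model always building a `BaseModelOutputWithPooling` internally, tuple outputs now come from the `@can_return_tuple` decorator rather than the removed `if not return_dict` branch. A hedged usage sketch follows; the tiny random-weight config and the reliance on MarkupLM filling in default xpath inputs when they are omitted are assumptions.

import torch
from transformers import MarkupLMConfig, MarkupLMModel

config = MarkupLMConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128)
model = MarkupLMModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))

with torch.no_grad():
    outputs = model(input_ids=input_ids)                      # BaseModelOutputWithPooling
    as_tuple = model(input_ids=input_ids, return_dict=False)  # (sequence_output, pooled_output, ...)

assert torch.allclose(outputs.last_hidden_state, as_tuple[0])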
# Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache # Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache
@ -879,6 +765,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -939,7 +826,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@ -966,10 +853,6 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
end_loss = loss_fct(end_logits, end_positions) end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2 total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput( return QuestionAnsweringModelOutput(
loss=total_loss, loss=total_loss,
start_logits=start_logits, start_logits=start_logits,
@ -1000,6 +883,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1058,7 +942,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = outputs[0] sequence_output = outputs[0]
@ -1072,10 +956,6 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
labels.view(-1), labels.view(-1),
) )
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, loss=loss,
logits=prediction_scores, logits=prediction_scores,
@ -1107,6 +987,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1164,7 +1045,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
pooled_output = outputs[1] pooled_output = outputs[1]
@ -1194,9 +1075,6 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
elif self.config.problem_type == "multi_label_classification": elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss() loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels) loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, loss=loss,


@ -354,11 +354,6 @@ class MaskFormerSwinSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -367,11 +362,11 @@ class MaskFormerSwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states) hidden_shape = (batch_size, dim, -1, self.attention_head_size)
key_layer = self.transpose_for_scores(self.key(hidden_states)) query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
query_layer = self.transpose_for_scores(mixed_query_layer) value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Take the dot product between "query" and "key" to get the raw attention scores. # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
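The removed `transpose_for_scores` helper and the new inline view/transpose produce the same (batch, heads, seq, head_dim) layout; a quick check of that equivalence:

import torch

batch, seq, num_heads, head_dim = 2, 5, 3, 4
x = torch.randn(batch, seq, num_heads * head_dim)

old = x.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)   # transpose_for_scores
new = x.view(batch, seq, -1, head_dim).transpose(1, 2)              # inlined variant
assert torch.equal(old, new)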


@ -308,11 +308,6 @@ class Mistral3Model(Mistral3PreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -324,10 +319,18 @@ class Mistral3Model(Mistral3PreTrainedModel):
) )
image_features = torch.cat(image_features, dim=0) image_features = torch.cat(image_features, dim=0)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) if input_ids is None:
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"


@ -204,11 +204,6 @@ class Mistral3Model(LlavaModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -220,10 +215,18 @@ class Mistral3Model(LlavaModel):
) )
image_features = torch.cat(image_features, dim=0) image_features = torch.cat(image_features, dim=0)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) if input_ids is None:
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"


@ -182,7 +182,6 @@ def eager_attention_forward(
return attn_output, attn_weights return attn_output, attn_weights
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Musicgen
class MusicgenAttention(nn.Module): class MusicgenAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""


@ -189,7 +189,7 @@ def eager_attention_forward(
return attn_output, attn_weights return attn_output, attn_weights
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->MusicgenMelody # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->MusicgenMelody
class MusicgenMelodyAttention(nn.Module): class MusicgenMelodyAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""


@ -503,7 +503,7 @@ def eager_attention_forward(
return attn_output, attn_weights return attn_output, attn_weights
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->NllbMoe,key_value_states->encoder_hidden_states # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->NllbMoe,key_value_states->encoder_hidden_states
class NllbMoeAttention(nn.Module): class NllbMoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""


@ -331,9 +331,11 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()( special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
) )
special_image_mask = special_image_mask.all(-1)
else: else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]


@ -29,6 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import auto_docstring, logging from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_patchtsmixer import PatchTSMixerConfig from .configuration_patchtsmixer import PatchTSMixerConfig
@ -303,6 +304,7 @@ class PatchTSMixerAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -310,7 +312,7 @@ class PatchTSMixerAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None, past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs # TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
@ -331,42 +333,9 @@ class PatchTSMixerAttention(nn.Module):
# get query proj # get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj current_states = key_value_states if is_cross_attention else hidden_states
# `past_key_value[0].shape[2] == key_value_states.shape[1]` key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# is checking that the `sequence_length` of the `past_key_value` is the same as value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
@ -388,7 +357,7 @@ class PatchTSMixerAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, None
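With the decoder cache paths dropped, key/value projection collapses to a single source tensor chosen by whether the call is cross-attention. A minimal sketch of that pattern, with made-up dimensions:

import torch
import torch.nn as nn

embed_dim, num_heads, head_dim = 8, 2, 4
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

def project_kv(hidden_states, key_value_states=None):
    is_cross_attention = key_value_states is not None
    current_states = key_value_states if is_cross_attention else hidden_states
    bsz, src_len, _ = current_states.shape
    kv_shape = (bsz, src_len, num_heads, head_dim)
    key_states = k_proj(current_states).view(kv_shape).transpose(1, 2)
    value_states = v_proj(current_states).view(kv_shape).transpose(1, 2)
    return key_states, value_states  # (bsz, num_heads, src_len, head_dim)

k, v = project_kv(torch.randn(1, 5, embed_dim))                                   # self-attention
k2, v2 = project_kv(torch.randn(1, 5, embed_dim), torch.randn(1, 7, embed_dim))   # cross-attention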
class PatchMixerBlock(nn.Module): class PatchMixerBlock(nn.Module):


@ -28,6 +28,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import ModelOutput, auto_docstring, logging from ...utils import ModelOutput, auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_patchtst import PatchTSTConfig from .configuration_patchtst import PatchTSTConfig
@ -100,6 +101,7 @@ class PatchTSTAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -107,7 +109,7 @@ class PatchTSTAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None, past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs # TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
@ -128,42 +130,9 @@ class PatchTSTAttention(nn.Module):
# get query proj # get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj current_states = key_value_states if is_cross_attention else hidden_states
# `past_key_value[0].shape[2] == key_value_states.shape[1]` key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# is checking that the `sequence_length` of the `past_key_value` is the same as value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
@ -185,7 +154,7 @@ class PatchTSTAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, None
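The `@deprecate_kwarg` decorator added above keeps the old keyword accepted while warning callers; a simplified stand-in is sketched below (the real helper in `transformers.utils.deprecation` also supports renaming and version gating).

import functools
import warnings

def deprecate_kwarg_sketch(name, version):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if kwargs.get(name) is not None:
                warnings.warn(f"`{name}` is deprecated and will be removed in v{version}.", FutureWarning)
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@deprecate_kwarg_sketch("past_key_value", "4.54.0")
def forward(hidden_states, past_key_value=None):
    return hidden_states

forward("x", past_key_value=("k", "v"))  # emits a FutureWarning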
class PatchTSTBatchNorm(nn.Module): class PatchTSTBatchNorm(nn.Module):


@ -607,6 +607,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
f" and `num_heads`: {self.num_heads})." f" and `num_heads`: {self.num_heads})."
) )
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attention_dropout = 0.0
self.is_decoder = False self.is_decoder = False
self.is_causal = False self.is_causal = False
@ -619,6 +620,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
@ -634,15 +636,6 @@ class Qwen2_5OmniAudioAttention(nn.Module):
value_states = value_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_mask = torch.full(
[1, 1, seq_length, key_states.shape[-2]],
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
@ -652,13 +645,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
query_states, query_states,
key_states, key_states,
value_states, value_states,
attention_mask, attention_mask=attention_mask,
dropout=0.0 if not self.training else self.dropout, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling, scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens, cu_seq_lens_k=cu_seqlens,
max_seqlen_q=max_seqlen, max_length_q=max_seqlen,
max_seqlen_k=max_seqlen, max_length_k=max_seqlen,
is_causal=False, is_causal=False,
**kwargs, **kwargs,
) )
@ -686,6 +679,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@ -704,6 +698,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
hidden_states = self.self_attn( hidden_states = self.self_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
**kwargs, **kwargs,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@ -785,6 +780,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
def set_input_embeddings(self, value: nn.Module): def set_input_embeddings(self, value: nn.Module):
self.conv1 = value self.conv1 = value
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be an exact match for FA2's `varlen` path.
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
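A numeric sketch of what `_prepare_attention_mask` builds: `cu_seqlens` holds cumulative segment boundaries, and the additive mask allows bidirectional attention only inside each block (shapes and the fill value mirror the code above).

import torch

def block_diagonal_mask(cu_seqlens, seq_length, dtype=torch.float32):
    mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
    return mask

cu_seqlens = torch.tensor([0, 3, 5])                 # two segments of lengths 3 and 2
mask = block_diagonal_mask(cu_seqlens, seq_length=5)
# positions 0-2 attend to each other, positions 3-4 attend to each other, nothing crosses blocks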
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -833,9 +847,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
padded_mask_after_cnn.sum(1).cumsum(0), padded_mask_after_cnn.sum(1).cumsum(0),
) )
).to(torch.int32) ).to(torch.int32)
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
for encoder_layer in self.layers: for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs) layer_outputs = encoder_layer(
hidden_states,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
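A small sketch of the packing scheme this encoder relies on: variable-length segments are concatenated along the first dimension, boundaries are tracked as cumulative lengths, and the result is split back per segment. Names and sizes are illustrative.

import torch

lengths = torch.tensor([3, 2, 4])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), lengths.cumsum(0).to(torch.int32)])
packed = torch.randn(int(lengths.sum()), 8)           # (total_len, hidden)
segments = packed.split(lengths.tolist(), dim=0)      # back to per-segment tensors
assert [s.shape[0] for s in segments] == lengths.tolist()
assert cu_seqlens.tolist() == [0, 3, 5, 9]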
@ -928,12 +948,15 @@ class Qwen2_5OmniVisionAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.num_key_value_groups = 1 # needed for eager attention self.num_key_value_groups = 1 # needed for eager attention
self.config = config self.config = config
self.attention_dropout = 0.0
self.is_causal = False
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
seq_length = hidden_states.shape[0] seq_length = hidden_states.shape[0]
@ -943,18 +966,9 @@ class Qwen2_5OmniVisionAttention(nn.Module):
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.full( query_states = query_states.transpose(0, 1).unsqueeze(0)
[1, 1, seq_length, seq_length], key_states = key_states.transpose(0, 1).unsqueeze(0)
torch.finfo(query_states.dtype).min, value_states = value_states.transpose(0, 1).unsqueeze(0)
device=query_states.device,
dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
@ -966,13 +980,13 @@ class Qwen2_5OmniVisionAttention(nn.Module):
query_states, query_states,
key_states, key_states,
value_states, value_states,
attention_mask, attention_mask=attention_mask,
dropout=0.0, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling, scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens, cu_seq_lens_k=cu_seqlens,
max_seqlen_q=max_seqlen, max_length_q=max_seqlen,
max_seqlen_k=max_seqlen, max_length_k=max_seqlen,
is_causal=False, is_causal=False,
**kwargs, **kwargs,
) )
@ -1009,10 +1023,15 @@ class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = hidden_states + self.attn( hidden_states = hidden_states + self.attn(
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs,
) )
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states return hidden_states
@ -1171,6 +1190,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
return window_index, cu_window_seqlens return window_index, cu_window_seqlens
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and not attending between
# blocks; it will not be an exact match for FA2's `varlen` path.
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
""" """
Args: Args:
@ -1217,10 +1255,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
cu_seqlens_now = cu_seqlens cu_seqlens_now = cu_seqlens
else: else:
cu_seqlens_now = cu_window_seqlens cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk( hidden_states = blk(
hidden_states, hidden_states,
cu_seqlens=cu_seqlens_now, cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs, **kwargs,
) )
hidden_states = self.merger(hidden_states) hidden_states = self.merger(hidden_states)
@ -1862,41 +1903,49 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text , audios , image and video # 2. Merge text , audios , image and video
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
if input_features is not None: if input_features is not None:
audio_features = self.get_audio_features( audio_features = self.get_audio_features(
input_features, input_features,
feature_attention_mask=feature_attention_mask, feature_attention_mask=feature_attention_mask,
audio_feature_lengths=audio_feature_lengths, audio_feature_lengths=audio_feature_lengths,
) )
audio_mask = ( if input_ids is None:
(input_ids == self.config.audio_token_id) audio_mask = inputs_embeds == self.get_input_embeddings()(
.unsqueeze(-1) torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
) )
audio_mask = audio_mask.all(-1)
else:
audio_mask = input_ids == self.config.audio_token_id
audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_mask = ( if input_ids is None:
(input_ids == self.config.image_token_id) image_mask = inputs_embeds == self.get_input_embeddings()(
.unsqueeze(-1) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
) )
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None: if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_mask = ( if input_ids is None:
(input_ids == self.config.video_token_id) video_mask = inputs_embeds == self.get_input_embeddings()(
.unsqueeze(-1) torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
) )
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
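A self-contained sketch of the `masked_scatter` merge used for every modality above: placeholder positions in the text embeddings are overwritten, in order, with the flattened multimodal features. The toy sizes are assumptions.

import torch

hidden = 4
inputs_embeds = torch.zeros(1, 6, hidden)
token_mask = torch.tensor([[False, True, True, False, True, False]])   # three placeholder tokens

features = torch.arange(3 * hidden, dtype=torch.float32).reshape(3, hidden)  # one row per placeholder
expanded_mask = token_mask.unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(expanded_mask, features)
# rows 1, 2 and 4 of `merged` now hold the three feature rows, in order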


@ -1611,6 +1611,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
f" and `num_heads`: {self.num_heads})." f" and `num_heads`: {self.num_heads})."
) )
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attention_dropout = 0.0
self.is_decoder = False self.is_decoder = False
self.is_causal = False self.is_causal = False
@ -1623,6 +1624,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
@ -1638,15 +1640,6 @@ class Qwen2_5OmniAudioAttention(nn.Module):
value_states = value_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_mask = torch.full(
[1, 1, seq_length, key_states.shape[-2]],
torch.finfo(query_states.dtype).min,
device=query_states.device,
dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
@ -1656,13 +1649,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
query_states, query_states,
key_states, key_states,
value_states, value_states,
attention_mask, attention_mask=attention_mask,
dropout=0.0 if not self.training else self.dropout, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling, scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens, cu_seq_lens_k=cu_seqlens,
max_seqlen_q=max_seqlen, max_length_q=max_seqlen,
max_seqlen_k=max_seqlen, max_length_k=max_seqlen,
is_causal=False, is_causal=False,
**kwargs, **kwargs,
) )
@ -1682,6 +1675,7 @@ class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
@ -1689,6 +1683,7 @@ class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer):
hidden_states = self.self_attn( hidden_states = self.self_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
**kwargs, **kwargs,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@ -1770,6 +1765,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
def set_input_embeddings(self, value: nn.Module): def set_input_embeddings(self, value: nn.Module):
self.conv1 = value self.conv1 = value
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and no attention between
# blocks. It will not be a 100% match for FA2's `varlen` path.
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
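For non-flash backends, the helper above turns `cu_seqlens` into an additive block-diagonal mask so each position can only attend within its own segment. A tiny numeric sketch of the resulting structure, with made-up lengths and `-inf` standing in for the dtype minimum:

# Sketch only: block-diagonal additive mask derived from cumulative sequence lengths.
import torch

cu_seqlens = torch.tensor([0, 2, 5])          # two segments, lengths 2 and 3
seq_length = int(cu_seqlens[-1])
mask = torch.full([1, 1, seq_length, seq_length], float("-inf"))
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

print((mask[0, 0] == 0).int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)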
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -1818,9 +1832,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
padded_mask_after_cnn.sum(1).cumsum(0), padded_mask_after_cnn.sum(1).cumsum(0),
) )
).to(torch.int32) ).to(torch.int32)
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
for encoder_layer in self.layers: for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs) layer_outputs = encoder_layer(
hidden_states,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)
@ -1906,12 +1926,15 @@ class Qwen2_5OmniVisionAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.num_key_value_groups = 1 # needed for eager attention self.num_key_value_groups = 1 # needed for eager attention
self.config = config self.config = config
self.attention_dropout = 0.0
self.is_causal = False
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
seq_length = hidden_states.shape[0] seq_length = hidden_states.shape[0]
@ -1921,18 +1944,9 @@ class Qwen2_5OmniVisionAttention(nn.Module):
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.full( query_states = query_states.transpose(0, 1).unsqueeze(0)
[1, 1, seq_length, seq_length], key_states = key_states.transpose(0, 1).unsqueeze(0)
torch.finfo(query_states.dtype).min, value_states = value_states.transpose(0, 1).unsqueeze(0)
device=query_states.device,
dtype=query_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
@ -1944,13 +1958,13 @@ class Qwen2_5OmniVisionAttention(nn.Module):
query_states, query_states,
key_states, key_states,
value_states, value_states,
attention_mask, attention_mask=attention_mask,
dropout=0.0, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling, scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens, cu_seq_lens_k=cu_seqlens,
max_seqlen_q=max_seqlen, max_length_q=max_seqlen,
max_seqlen_k=max_seqlen, max_length_k=max_seqlen,
is_causal=False, is_causal=False,
**kwargs, **kwargs,
) )
@ -1970,10 +1984,15 @@ class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = hidden_states + self.attn( hidden_states = hidden_states + self.attn(
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs,
) )
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states return hidden_states
@ -1987,6 +2006,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)]) self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)])
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and no attention between
# blocks. It will not be a 100% match for FA2's `varlen` path.
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
""" """
Args: Args:
@ -2033,10 +2071,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
cu_seqlens_now = cu_seqlens cu_seqlens_now = cu_seqlens
else: else:
cu_seqlens_now = cu_window_seqlens cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk( hidden_states = blk(
hidden_states, hidden_states,
cu_seqlens=cu_seqlens_now, cu_seqlens=cu_seqlens_now,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
attention_mask=attention_mask,
**kwargs, **kwargs,
) )
hidden_states = self.merger(hidden_states) hidden_states = self.merger(hidden_states)
@ -2309,41 +2350,49 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text , audios , image and video # 2. Merge text , audios , image and video
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
if input_features is not None: if input_features is not None:
audio_features = self.get_audio_features( audio_features = self.get_audio_features(
input_features, input_features,
feature_attention_mask=feature_attention_mask, feature_attention_mask=feature_attention_mask,
audio_feature_lengths=audio_feature_lengths, audio_feature_lengths=audio_feature_lengths,
) )
audio_mask = ( if input_ids is None:
(input_ids == self.config.audio_token_id) audio_mask = inputs_embeds == self.get_input_embeddings()(
.unsqueeze(-1) torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
) )
audio_mask = audio_mask.all(-1)
else:
audio_mask = input_ids == self.config.audio_token_id
audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_mask = ( if input_ids is None:
(input_ids == self.config.image_token_id) image_mask = inputs_embeds == self.get_input_embeddings()(
.unsqueeze(-1) torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
) )
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None: if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_mask = ( if input_ids is None:
(input_ids == self.config.video_token_id) video_mask = inputs_embeds == self.get_input_embeddings()(
.unsqueeze(-1) torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
) )
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

View File

@ -206,6 +206,8 @@ class Qwen2_5_VLVisionAttention(nn.Module):
self.proj = nn.Linear(self.dim, self.dim) self.proj = nn.Linear(self.dim, self.dim)
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.config = config self.config = config
self.attention_dropout = 0.0
self.is_causal = False
def forward( def forward(
self, self,
@ -213,6 +215,7 @@ class Qwen2_5_VLVisionAttention(nn.Module):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
seq_length = hidden_states.shape[0] seq_length = hidden_states.shape[0]
@ -233,18 +236,9 @@ class Qwen2_5_VLVisionAttention(nn.Module):
cos, sin = position_embeddings cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
attention_mask = torch.full( query_states = query_states.transpose(0, 1).unsqueeze(0)
[1, 1, seq_length, seq_length], key_states = key_states.transpose(0, 1).unsqueeze(0)
torch.finfo(value_states.dtype).min, value_states = value_states.transpose(0, 1).unsqueeze(0)
device=value_states.device,
dtype=value_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
@ -256,13 +250,13 @@ class Qwen2_5_VLVisionAttention(nn.Module):
query_states, query_states,
key_states, key_states,
value_states, value_states,
attention_mask, attention_mask=attention_mask,
dropout=0.0, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling, scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens, cu_seq_lens_k=cu_seqlens,
max_seqlen_q=max_seqlen, max_length_q=max_seqlen,
max_seqlen_k=max_seqlen, max_length_k=max_seqlen,
is_causal=False, is_causal=False,
**kwargs, **kwargs,
) )
@ -286,6 +280,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = hidden_states + self.attn( hidden_states = hidden_states + self.attn(
@ -293,6 +288,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs, **kwargs,
) )
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@ -426,6 +422,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return window_index, cu_window_seqlens return window_index, cu_window_seqlens
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and no attention between
# blocks. It will not be a 100% match for FA2's `varlen` path.
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
""" """
Args: Args:
@ -472,8 +487,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
cu_seqlens_now = cu_seqlens cu_seqlens_now = cu_seqlens
else: else:
cu_seqlens_now = cu_window_seqlens cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk( hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs hidden_states,
cu_seqlens=cu_seqlens_now,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
) )
hidden_states = self.merger(hidden_states) hidden_states = self.merger(hidden_states)
@ -1224,39 +1245,49 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0) image_embeds = torch.cat(image_embeds, dim=0)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id
n_image_tokens = (image_mask).sum()
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_features = image_embeds.shape[0] n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
) )
mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None: if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0) video_embeds = torch.cat(video_embeds, dim=0)
n_video_tokens = (input_ids == self.config.video_token_id).sum()
if input_ids is None:
video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id
n_video_tokens = (video_mask).sum()
n_video_features = video_embeds.shape[0] n_video_features = video_embeds.shape[0]
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError( raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
) )
mask = input_ids == self.config.video_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
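Before each `masked_scatter`, the number of placeholder tokens must equal the number of feature rows, otherwise the scatter would silently misalign text and vision features. A small illustration of the guard, with made-up ids and shapes:

# Sketch only: the token/feature consistency check that precedes the scatter.
import torch

image_token_id = 151655                          # assumed placeholder id, illustration only
input_ids = torch.tensor([[101, image_token_id, image_token_id, 102]])
image_embeds = torch.randn(3, 4096)              # 3 feature rows but only 2 placeholders

n_image_tokens = (input_ids == image_token_id).sum()
n_image_features = image_embeds.shape[0]
try:
    if n_image_tokens != n_image_features:
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
        )
except ValueError as e:
    print(e)   # tokens: 2, features 3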
@ -1565,6 +1596,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
def _get_image_nums_and_video_nums( def _get_image_nums_and_video_nums(
self, self,
input_ids: Optional[torch.LongTensor], input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Get the number of images and videos for each sample to calculate the separation length of the sample tensor. Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@ -1582,10 +1614,31 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
video_token_id = self.config.video_token_id video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id vision_start_token_id = self.config.vision_start_token_id
if inputs_embeds is not None:
vision_start_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
image_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
video_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
else:
vision_start_mask = input_ids == vision_start_token_id vision_start_mask = input_ids == vision_start_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_mask = input_ids == image_token_id image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id video_mask = input_ids == video_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_nums = torch.sum(vision_first_mask & image_mask, dim=1) image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1) video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
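The counting above relies on the layout `<vision_start><image/video placeholders ...>`: rolling the vision-start mask right by one marks the token immediately following each vision start, and intersecting it with the image/video masks gives per-sample media counts. A small numeric sketch with made-up token ids:

# Sketch only: per-sample image/video counting via a shifted vision-start mask.
import torch

vision_start_token_id, image_token_id, video_token_id = 90, 91, 92   # illustrative ids
input_ids = torch.tensor([[1, 90, 91, 91, 2, 90, 92, 92, 2]])

vision_start_mask = input_ids == vision_start_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)  # token right after each start
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id

image_nums = torch.sum(vision_first_mask & image_mask, dim=1)   # tensor([1])
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)   # tensor([1])
print(image_nums, video_nums)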
@ -1611,7 +1664,9 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
def _expand_dict_for_generation_visual(dict_to_expand): def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None) image_grid_thw = model_kwargs.get("image_grid_thw", None)
video_grid_thw = model_kwargs.get("video_grid_thw", None) video_grid_thw = model_kwargs.get("video_grid_thw", None)
image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) image_nums, video_nums = self._get_image_nums_and_video_nums(
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
)
def _repeat_interleave_samples(x, lengths, repeat_times): def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths) samples = torch.split(x, lengths)
@ -1667,9 +1722,6 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand return dict_to_expand
# input_ids is required for expanding visual inputs
# If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
if input_ids is not None and input_ids.numel() != 0:
model_kwargs = _expand_dict_for_generation_visual(model_kwargs) model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
if input_ids is not None: if input_ids is not None:

View File

@ -159,6 +159,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = hidden_states + self.attn( hidden_states = hidden_states + self.attn(
@ -166,6 +167,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs, **kwargs,
) )
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@ -287,6 +289,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return window_index, cu_window_seqlens return window_index, cu_window_seqlens
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and no attention between
# blocks. It will not be a 100% match for FA2's `varlen` path.
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
""" """
Args: Args:
@ -333,8 +354,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
cu_seqlens_now = cu_seqlens cu_seqlens_now = cu_seqlens
else: else:
cu_seqlens_now = cu_window_seqlens cu_seqlens_now = cu_window_seqlens
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
hidden_states = blk( hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs hidden_states,
cu_seqlens=cu_seqlens_now,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
) )
hidden_states = self.merger(hidden_states) hidden_states = self.merger(hidden_states)
@ -582,39 +609,49 @@ class Qwen2_5_VLModel(Qwen2VLModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0) image_embeds = torch.cat(image_embeds, dim=0)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id
n_image_tokens = (image_mask).sum()
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_features = image_embeds.shape[0] n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
) )
mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None: if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0) video_embeds = torch.cat(video_embeds, dim=0)
n_video_tokens = (input_ids == self.config.video_token_id).sum()
if input_ids is None:
video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id
n_video_tokens = (video_mask).sum()
n_video_features = video_embeds.shape[0] n_video_features = video_embeds.shape[0]
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError( raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
) )
mask = input_ids == self.config.video_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

View File

@ -324,6 +324,8 @@ class VisionAttention(nn.Module):
self.proj = nn.Linear(self.dim, self.dim) self.proj = nn.Linear(self.dim, self.dim)
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.config = config self.config = config
self.attention_dropout = 0.0
self.is_causal = False
def forward( def forward(
self, self,
@ -331,6 +333,7 @@ class VisionAttention(nn.Module):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
seq_length = hidden_states.shape[0] seq_length = hidden_states.shape[0]
@ -351,18 +354,9 @@ class VisionAttention(nn.Module):
cos, sin = position_embeddings cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
attention_mask = torch.full( query_states = query_states.transpose(0, 1).unsqueeze(0)
[1, 1, seq_length, seq_length], key_states = key_states.transpose(0, 1).unsqueeze(0)
torch.finfo(value_states.dtype).min, value_states = value_states.transpose(0, 1).unsqueeze(0)
device=value_states.device,
dtype=value_states.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
@ -374,13 +368,13 @@ class VisionAttention(nn.Module):
query_states, query_states,
key_states, key_states,
value_states, value_states,
attention_mask, attention_mask=attention_mask,
dropout=0.0, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling, scaling=self.scaling,
cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seqlens_k=cu_seqlens, cu_seq_lens_k=cu_seqlens,
max_seqlen_q=max_seqlen, max_length_q=max_seqlen,
max_seqlen_k=max_seqlen, max_length_k=max_seqlen,
is_causal=False, is_causal=False,
**kwargs, **kwargs,
) )
@ -406,6 +400,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None, rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = hidden_states + self.attn( hidden_states = hidden_states + self.attn(
@ -413,6 +408,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer):
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs, **kwargs,
) )
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@ -725,6 +721,25 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb return rotary_pos_emb
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
# NOTE: the created attention mask only approximates the ragged FA2 attention by
# allowing bidirectional attention within `cu_seqlens` blocks and no attention between
# blocks. It will not be a 100% match for FA2's `varlen` path.
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -750,10 +765,15 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
) )
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
for blk in self.blocks: for blk in self.blocks:
hidden_states = blk( hidden_states = blk(
hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, **kwargs hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
) )
return self.merger(hidden_states) return self.merger(hidden_states)
@ -1162,39 +1182,50 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0) image_embeds = torch.cat(image_embeds, dim=0)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id
n_image_tokens = image_mask.sum()
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_features = image_embeds.shape[0] n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
) )
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None: if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0) video_embeds = torch.cat(video_embeds, dim=0)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
if input_ids is None:
video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id
n_video_tokens = video_mask.sum()
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_video_features = video_embeds.shape[0] n_video_features = video_embeds.shape[0]
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError( raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
) )
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
@ -1460,6 +1491,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
def _get_image_nums_and_video_nums( def _get_image_nums_and_video_nums(
self, self,
input_ids: Optional[torch.LongTensor], input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Get the number of images and videos for each sample to calculate the separation length of the sample tensor. Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@ -1477,10 +1509,31 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
video_token_id = self.config.video_token_id video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id vision_start_token_id = self.config.vision_start_token_id
if inputs_embeds is not None:
vision_start_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
image_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
video_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
else:
vision_start_mask = input_ids == vision_start_token_id vision_start_mask = input_ids == vision_start_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_mask = input_ids == image_token_id image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id video_mask = input_ids == video_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_nums = torch.sum(vision_first_mask & image_mask, dim=1) image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1) video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
@ -1506,7 +1559,9 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
def _expand_dict_for_generation_visual(dict_to_expand): def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None) image_grid_thw = model_kwargs.get("image_grid_thw", None)
video_grid_thw = model_kwargs.get("video_grid_thw", None) video_grid_thw = model_kwargs.get("video_grid_thw", None)
image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) image_nums, video_nums = self._get_image_nums_and_video_nums(
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
)
def _repeat_interleave_samples(x, lengths, repeat_times): def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths) samples = torch.split(x, lengths)
@ -1562,9 +1617,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand return dict_to_expand
# input_ids is required for expanding visual inputs
# If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
if input_ids is not None and input_ids.numel() != 0:
model_kwargs = _expand_dict_for_generation_visual(model_kwargs) model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
if input_ids is not None: if input_ids is not None:

View File

@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import auto_docstring, logging from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_sew import SEWConfig from .configuration_sew import SEWConfig
@ -293,6 +294,7 @@ class SEWAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -300,7 +302,7 @@ class SEWAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None, past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs # TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
@ -321,42 +323,9 @@ class SEWAttention(nn.Module):
# get query proj # get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj current_states = key_value_states if is_cross_attention else hidden_states
# `past_key_value[0].shape[2] == key_value_states.shape[1]` key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# is checking that the `sequence_length` of the `past_key_value` is the same as value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
@ -378,7 +347,7 @@ class SEWAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, None
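With the key/value caching branches removed, the only branching left in this encoder-only attention is choosing the projection source for self- vs cross-attention. A minimal sketch of that reduced pattern, with illustrative shapes and layer names rather than the full SEW module:

# Sketch only: simplified key/value projection path once past_key_value handling is gone.
import torch
import torch.nn as nn

embed_dim, num_heads, head_dim = 16, 4, 4
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

def project_kv(hidden_states, key_value_states=None):
    is_cross_attention = key_value_states is not None
    current_states = key_value_states if is_cross_attention else hidden_states
    bsz, src_len, _ = current_states.shape
    key_states = k_proj(current_states).view(bsz, src_len, num_heads, head_dim).transpose(1, 2)
    value_states = v_proj(current_states).view(bsz, src_len, num_heads, head_dim).transpose(1, 2)
    return key_states, value_states

k, v = project_kv(torch.randn(2, 10, embed_dim))
print(k.shape, v.shape)   # torch.Size([2, 4, 10, 4]) torch.Size([2, 4, 10, 4])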
class SEWFeedForward(nn.Module): class SEWFeedForward(nn.Module):

View File

@ -595,7 +595,14 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
""" """
_, patch_size, _ = image_hidden_states.shape _, patch_size, _ = image_hidden_states.shape
image_mask = input_ids == self.image_token_id if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask[..., 0] # slice off the hidden dim
else:
image_mask = input_ids == self.config.image_token_id
num_image_tokens = image_mask.sum(dim=1) num_image_tokens = image_mask.sum(dim=1)
if not torch.all(num_image_tokens % patch_size == 0): if not torch.all(num_image_tokens % patch_size == 0):
raise ValueError("At least one sample has <image> tokens not divisible by patch_size.") raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")
@ -717,14 +724,8 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
past_seen_tokens = 0 if use_cache and past_key_values is None:
if use_cache:
if past_key_values is None:
past_key_values = DynamicCache() past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length()
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
@ -732,12 +733,13 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
# START VISUAL INPUTS INTEGRATION # START VISUAL INPUTS INTEGRATION
if pixel_values is not None and image_hidden_states is not None: if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(input_ids.device)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
if inputs_embeds is not None and image_hidden_states is not None: if pixel_values is not None:
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(inputs_embeds.device)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)
if image_hidden_states is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images # When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist # that simply don't exist
inputs_embeds = self.inputs_merger( inputs_embeds = self.inputs_merger(
@ -996,27 +998,11 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
**kwargs, **kwargs,
) )
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step if image_hidden_states is not None or cache_position[0] != 0:
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs["input_ids"] = input_ids
if image_hidden_states is not None:
model_inputs["pixel_values"] = None model_inputs["pixel_values"] = None
model_inputs["pixel_attention_mask"] = None model_inputs["pixel_attention_mask"] = None
return model_inputs return model_inputs
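The simplified condition above means visual tensors are only forwarded on the prefill step: once `image_hidden_states` exists or decoding has started (`cache_position[0] != 0`), `pixel_values` and `pixel_attention_mask` are dropped. A small sketch of that behavior with assumed, simplified inputs:

# Sketch only: prune visual inputs outside the prefill step.
import torch

def prune_visual_inputs(model_inputs, cache_position, image_hidden_states=None):
    if image_hidden_states is not None or cache_position[0] != 0:
        model_inputs["pixel_values"] = None
        model_inputs["pixel_attention_mask"] = None
    return model_inputs

prefill = prune_visual_inputs(
    {"pixel_values": torch.zeros(1), "pixel_attention_mask": torch.ones(1)}, torch.tensor([0])
)
decode = prune_visual_inputs(
    {"pixel_values": torch.zeros(1), "pixel_attention_mask": torch.ones(1)}, torch.tensor([5])
)
print(prefill["pixel_values"] is None, decode["pixel_values"] is None)   # False True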
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
return model_kwargs
__all__ = ["SmolVLMForConditionalGeneration", "SmolVLMPreTrainedModel", "SmolVLMModel", "SmolVLMVisionTransformer"] __all__ = ["SmolVLMForConditionalGeneration", "SmolVLMPreTrainedModel", "SmolVLMModel", "SmolVLMVisionTransformer"]

View File

@ -180,7 +180,14 @@ class SmolVLMModel(Idefics3Model):
): ):
_, patch_size, _ = image_hidden_states.shape _, patch_size, _ = image_hidden_states.shape
image_mask = input_ids == self.image_token_id if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask[..., 0] # slice off the hidden dim
else:
image_mask = input_ids == self.config.image_token_id
num_image_tokens = image_mask.sum(dim=1) num_image_tokens = image_mask.sum(dim=1)
if not torch.all(num_image_tokens % patch_size == 0): if not torch.all(num_image_tokens % patch_size == 0):
raise ValueError("At least one sample has <image> tokens not divisible by patch_size.") raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")
@ -296,14 +303,8 @@ class SmolVLMModel(Idefics3Model):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
past_seen_tokens = 0 if use_cache and past_key_values is None:
if use_cache:
if past_key_values is None:
past_key_values = DynamicCache() past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length()
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
@ -311,12 +312,13 @@ class SmolVLMModel(Idefics3Model):
# START VISUAL INPUTS INTEGRATION # START VISUAL INPUTS INTEGRATION
if pixel_values is not None and image_hidden_states is not None: if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(input_ids.device)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
if inputs_embeds is not None and image_hidden_states is not None: if pixel_values is not None:
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(inputs_embeds.device)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)
if image_hidden_states is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images # When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist # that simply don't exist
inputs_embeds = self.inputs_merger( inputs_embeds = self.inputs_merger(

View File

@ -205,7 +205,7 @@ def eager_attention_forward(
return attn_output, attn_weights return attn_output, attn_weights
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Speech2Text # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->Speech2Text
class Speech2TextAttention(nn.Module): class Speech2TextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""

View File

@ -14,9 +14,8 @@
# limitations under the License. # limitations under the License.
"""PyTorch Splinter model.""" """PyTorch Splinter model."""
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union from typing import Callable, Optional, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@ -25,13 +24,19 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput from ...modeling_outputs import (
from ...modeling_utils import PreTrainedModel BaseModelOutput,
ModelOutput,
QuestionAnsweringModelOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ( from ...utils import (
auto_docstring, auto_docstring,
can_return_tuple,
logging, logging,
) )
from ...utils.deprecation import deprecate_kwarg
from .configuration_splinter import SplinterConfig from .configuration_splinter import SplinterConfig
@ -64,7 +69,6 @@ class SplinterEmbeddings(nn.Module):
token_type_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: Optional[int] = 0,
) -> tuple: ) -> tuple:
if input_ids is not None: if input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
@ -74,7 +78,7 @@ class SplinterEmbeddings(nn.Module):
seq_length = input_shape[1] seq_length = input_shape[1]
if position_ids is None: if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
@ -92,9 +96,37 @@ class SplinterEmbeddings(nn.Module):
return embeddings return embeddings
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter # Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
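As a quick shape check of the `eager_attention_forward` helper introduced above, it can be called directly with toy tensors (sizes here are arbitrary); the only attribute it reads from `module` is `training`:

import torch
import torch.nn as nn

batch, heads, seq, head_dim = 2, 4, 6, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

module = nn.Module()  # provides the .training flag used by the dropout call

attn_output, attn_weights = eager_attention_forward(
    module, q, k, v, attention_mask=None, scaling=head_dim**-0.5, dropout=0.0
)
print(attn_output.shape)   # torch.Size([2, 6, 4, 8]) after the final transpose
print(attn_weights.shape)  # torch.Size([2, 4, 6, 6])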
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->Splinter
class SplinterSelfAttention(nn.Module): class SplinterSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None): def __init__(self, config):
super().__init__() super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError( raise ValueError(
@ -102,6 +134,7 @@ class SplinterSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})" f"heads ({config.num_attention_heads})"
) )
self.config = config
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -111,20 +144,12 @@ class SplinterSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr( self.attention_dropout = config.attention_probs_dropout_prob
config, "position_embedding_type", "absolute" self.scaling = self.attention_head_size**-0.5
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -134,96 +159,33 @@ class SplinterSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states) input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
# and values come from an encoder; the attention mask needs to be key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
# such that the encoder's padding tokens are not attended to. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None: attention_interface: Callable = eager_attention_forward
# reuse k,v, cross_attentions if self.config._attn_implementation != "eager":
key_layer = past_key_value[0] attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer) attn_output, attn_weights = attention_interface(
self,
use_cache = past_key_value is not None query_states,
if self.is_decoder: key_states,
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. value_states,
# Further calls to cross_attention layer can then reuse all cross-attention attention_mask,
# key/value_states (first "if" case) dropout=0.0 if not self.training else self.attention_dropout,
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of scaling=self.scaling,
# all previous decoder key/value_states. Further calls to uni-directional self-attention head_mask=head_mask,
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) **kwargs,
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
) )
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) attn_output = attn_output.reshape(*input_shape, -1).contiguous()
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in SplinterModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs return outputs
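The attention dispatch used above falls back to the eager helper unless the config names a registered implementation. A generic sketch of that selection logic; `ALL_ATTENTION_FUNCTIONS` is the real registry in transformers, but the registry and config objects below are stand-ins for illustration:

from types import SimpleNamespace
from typing import Callable

def sdpa_attention_forward(*args, **kwargs):
    raise NotImplementedError("placeholder for a registered kernel")

ATTENTION_REGISTRY = {"sdpa": sdpa_attention_forward}  # stand-in for ALL_ATTENTION_FUNCTIONS

def eager_fn(*args, **kwargs):
    return "eager"

def pick_attention(config, eager: Callable) -> Callable:
    # Mirrors the pattern above: keep eager unless another implementation is configured.
    if config._attn_implementation != "eager":
        return ATTENTION_REGISTRY[config._attn_implementation]
    return eager

print(pick_attention(SimpleNamespace(_attn_implementation="eager"), eager_fn) is eager_fn)  # True
print(pick_attention(SimpleNamespace(_attn_implementation="sdpa"), eager_fn) is eager_fn)   # False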
@ -242,18 +204,11 @@ class SplinterSelfOutput(nn.Module):
return hidden_states return hidden_states
SPLINTER_SELF_ATTENTION_CLASSES = { # Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->Splinter
"eager": SplinterSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER
class SplinterAttention(nn.Module): class SplinterAttention(nn.Module):
def __init__(self, config, position_embedding_type=None): def __init__(self, config):
super().__init__() super().__init__()
self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation]( self.self = SplinterSelfAttention(config)
config, position_embedding_type=position_embedding_type
)
self.output = SplinterSelfOutput(config) self.output = SplinterSelfOutput(config)
self.pruned_heads = set() self.pruned_heads = set()
@ -275,6 +230,9 @@ class SplinterAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -284,15 +242,14 @@ class SplinterAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
encoder_hidden_states, output_attentions=output_attentions,
encoder_attention_mask, **kwargs,
past_key_value,
output_attentions,
) )
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -330,22 +287,19 @@ class SplinterOutput(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter # Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->Splinter
class SplinterLayer(GradientCheckpointingLayer): class SplinterLayer(GradientCheckpointingLayer):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1 self.seq_len_dim = 1
self.attention = SplinterAttention(config) self.attention = SplinterAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = SplinterAttention(config, position_embedding_type="absolute")
self.intermediate = SplinterIntermediate(config) self.intermediate = SplinterIntermediate(config)
self.output = SplinterOutput(config) self.output = SplinterOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -355,60 +309,23 @@ class SplinterLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
hidden_states, hidden_states,
attention_mask, attention_mask=attention_mask,
head_mask, head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
past_key_value=self_attn_past_key_value, **kwargs,
) )
attention_output = self_attention_outputs[0] attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward( layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
) )
outputs = (layer_output,) + outputs outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs return outputs
def feed_forward_chunk(self, attention_output): def feed_forward_chunk(self, attention_output):
@ -417,14 +334,19 @@ class SplinterLayer(GradientCheckpointingLayer):
return layer_output return layer_output
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Splinter # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->Splinter
class SplinterEncoder(nn.Module): class SplinterEncoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)]) self.layer = nn.ModuleList([SplinterLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -437,65 +359,36 @@ class SplinterEncoder(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True, return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: **kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states=hidden_states,
attention_mask, attention_mask=attention_mask,
layer_head_mask, head_mask=layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
**kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],) all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict: return BaseModelOutput(
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attentions, attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
) )
@ -554,6 +447,11 @@ class SplinterModel(SplinterPreTrainedModel):
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
self, self,
@ -570,7 +468,7 @@ class SplinterModel(SplinterPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[tuple, BaseModelOutput]:
r""" r"""
token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*): token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
@ -592,11 +490,6 @@ class SplinterModel(SplinterPreTrainedModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
@ -610,11 +503,8 @@ class SplinterModel(SplinterPreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) attention_mask = torch.ones(((batch_size, seq_length)), device=device)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
@ -622,17 +512,6 @@ class SplinterModel(SplinterPreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads. # ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
@ -645,31 +524,21 @@ class SplinterModel(SplinterPreTrainedModel):
position_ids=position_ids, position_ids=position_ids,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
) )
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
head_mask=head_mask, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=True,
) )
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
if not return_dict: return BaseModelOutput(
return (sequence_output,) + encoder_outputs[1:]
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states, hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions, attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
) )
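With `SplinterEncoder` and `SplinterModel` now building a `BaseModelOutput` unconditionally, the tuple return path is handled by the `can_return_tuple` decorator rather than a hand-written `if not return_dict` branch. A hedged usage sketch (the checkpoint name is illustrative, and the tuple conversion is assumed to come from the decorator):

import torch
from transformers import AutoTokenizer, SplinterModel

checkpoint = "tau/splinter-base"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = SplinterModel.from_pretrained(checkpoint)

inputs = tokenizer("Splinter was pretrained for span selection.", return_tensors="pt")

with torch.no_grad():
    out = model(**inputs)                           # BaseModelOutput
    out_tuple = model(**inputs, return_dict=False)  # converted to a plain tuple by the decorator

print(type(out).__name__, out.last_hidden_state.shape)
print(isinstance(out_tuple, tuple))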

View File

@ -725,8 +725,8 @@ class SuperGlueForKeypointMatching(SuperGluePreTrainedModel):
matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
matches = torch.cat([matches0, matches1]).reshape(batch_size, 2, -1) matches = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1)
matching_scores = torch.cat([matching_scores0, matching_scores1]).reshape(batch_size, 2, -1) matching_scores = torch.cat([matching_scores0, matching_scores1], dim=1).reshape(batch_size, 2, -1)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + encoded_keypoints[1] all_hidden_states = all_hidden_states + encoded_keypoints[1]
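The `dim=1` fix above matters as soon as `batch_size > 1`: concatenating along the batch dimension and then reshaping interleaves matches from different image pairs. A toy illustration with made-up match tensors:

import torch

batch, k = 2, 3
matches0 = torch.arange(batch * k).reshape(batch, k)        # matches image0 -> image1, per batch item
matches1 = 100 + torch.arange(batch * k).reshape(batch, k)  # matches image1 -> image0, per batch item

old = torch.cat([matches0, matches1]).reshape(batch, 2, -1)         # dim=0: mixes batch items
new = torch.cat([matches0, matches1], dim=1).reshape(batch, 2, -1)  # dim=1: keeps them paired

print(old[0])  # rows come from matches0 of *both* batch items
print(new[0])  # matches0 and matches1 of the same batch item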

View File

@ -435,11 +435,6 @@ class SwinSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -448,11 +443,11 @@ class SwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states) hidden_shape = (batch_size, dim, -1, self.attention_head_size)
key_layer = self.transpose_for_scores(self.key(hidden_states)) query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
query_layer = self.transpose_for_scores(mixed_query_layer) value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Take the dot product between "query" and "key" to get the raw attention scores. # Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
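The inline `view(...).transpose(1, 2)` used above produces the same (batch, heads, seq, head_dim) layout as the removed `transpose_for_scores` helper; a small equivalence check with assumed toy dimensions:

import torch

batch, seq, heads, head_dim = 2, 5, 3, 4
x = torch.randn(batch, seq, heads * head_dim)

old = x.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)  # former helper
new = x.view(batch, seq, -1, head_dim).transpose(1, 2)         # new inline form

assert torch.equal(old, new)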

View File

@ -45,6 +45,7 @@ from ...modeling_outputs import (
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_unispeech import UniSpeechConfig from .configuration_unispeech import UniSpeechConfig
@ -332,6 +333,7 @@ class UniSpeechAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -339,7 +341,7 @@ class UniSpeechAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None, past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs # TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
@ -360,42 +362,9 @@ class UniSpeechAttention(nn.Module):
# get query proj # get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj current_states = key_value_states if is_cross_attention else hidden_states
# `past_key_value[0].shape[2] == key_value_states.shape[1]` key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# is checking that the `sequence_length` of the `past_key_value` is the same as value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
@ -417,7 +386,7 @@ class UniSpeechAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, None
class UniSpeechFeedForward(nn.Module): class UniSpeechFeedForward(nn.Module):
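The same simplification is repeated below for UniSpeechSat and Wav2Vec2: with the legacy `past_key_value` branches removed, choosing the key/value source reduces to a single selection. A minimal sketch with hypothetical projection layers:

import torch
import torch.nn as nn

embed_dim, heads = 16, 4
head_dim = embed_dim // heads
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

def project_kv(hidden_states, key_value_states=None):
    # Self-attention reads hidden_states; cross-attention reads the encoder states.
    current_states = key_value_states if key_value_states is not None else hidden_states
    bsz, src_len, _ = current_states.shape
    kv_shape = (bsz, src_len, heads, head_dim)
    key_states = k_proj(current_states).view(*kv_shape).transpose(1, 2)
    value_states = v_proj(current_states).view(*kv_shape).transpose(1, 2)
    return key_states, value_states

hidden_states = torch.randn(2, 7, embed_dim)
encoder_states = torch.randn(2, 9, embed_dim)
print(project_kv(hidden_states)[0].shape)                  # torch.Size([2, 4, 7, 4])
print(project_kv(hidden_states, encoder_states)[0].shape)  # torch.Size([2, 4, 9, 4])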

View File

@ -47,6 +47,7 @@ from ...modeling_outputs import (
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_unispeech_sat import UniSpeechSatConfig from .configuration_unispeech_sat import UniSpeechSatConfig
@ -337,6 +338,7 @@ class UniSpeechSatAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -344,7 +346,7 @@ class UniSpeechSatAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None, past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs # TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
@ -365,42 +367,9 @@ class UniSpeechSatAttention(nn.Module):
# get query proj # get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj current_states = key_value_states if is_cross_attention else hidden_states
# `past_key_value[0].shape[2] == key_value_states.shape[1]` key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# is checking that the `sequence_length` of the `past_key_value` is the same as value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
@ -422,7 +391,7 @@ class UniSpeechSatAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, None
class UniSpeechSatFeedForward(nn.Module): class UniSpeechSatFeedForward(nn.Module):

View File

@ -328,12 +328,6 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
"time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -343,10 +337,18 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
vision_feature_layer=vision_feature_layer, vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy, vision_feature_select_strategy=vision_feature_select_strategy,
) )
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) if input_ids is None:
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@ -359,10 +361,18 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
) )
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) if input_ids is None:
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.video_token_id
n_video_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_features.shape[0] * video_features.shape[1] n_video_features = video_features.shape[0] * video_features.shape[1]
raise ValueError( raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"

View File

@ -233,11 +233,6 @@ class VipLlavaModel(VipLlavaPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -246,10 +241,18 @@ class VipLlavaModel(VipLlavaPreTrainedModel):
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
) )
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) if input_ids is None:
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@ -136,11 +136,6 @@ class VipLlavaModel(LlavaModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@ -149,10 +144,18 @@ class VipLlavaModel(LlavaModel):
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
) )
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) if input_ids is None:
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@ -55,6 +55,7 @@ from ...utils import (
is_torch_flex_attn_available, is_torch_flex_attn_available,
logging, logging,
) )
from ...utils.deprecation import deprecate_kwarg
from .configuration_wav2vec2 import Wav2Vec2Config from .configuration_wav2vec2 import Wav2Vec2Config
@ -524,6 +525,7 @@ class Wav2Vec2Attention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -531,7 +533,7 @@ class Wav2Vec2Attention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None, past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs # TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
@ -552,42 +554,9 @@ class Wav2Vec2Attention(nn.Module):
# get query proj # get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj current_states = key_value_states if is_cross_attention else hidden_states
# `past_key_value[0].shape[2] == key_value_states.shape[1]` key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# is checking that the `sequence_length` of the `past_key_value` is the same as value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
@ -609,7 +578,7 @@ class Wav2Vec2Attention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, None
class Wav2Vec2FeedForward(nn.Module): class Wav2Vec2FeedForward(nn.Module):

View File

@ -30,6 +30,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
auto_docstring, auto_docstring,
can_return_tuple,
logging, logging,
torch_int, torch_int,
) )
@ -576,6 +577,7 @@ class XCLIPEncoder(nn.Module):
self.layers = nn.ModuleList([XCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([XCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@can_return_tuple
def forward( def forward(
self, self,
inputs_embeds, inputs_embeds,
@ -642,8 +644,6 @@ class XCLIPEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
) )

View File

@ -1642,7 +1642,6 @@ def set_model_tester_for_less_flaky_test(test_case):
"AriaVisionText2TextModelTester", "AriaVisionText2TextModelTester",
"GPTNeoModelTester", "GPTNeoModelTester",
"DPTModelTester", "DPTModelTester",
"Gemma3nTextModelTester", # cannot have a single layer combined with the cache sharing config attrs in the tester
] ]
if test_case.model_tester.__class__.__name__ in exceptional_classes: if test_case.model_tester.__class__.__name__ in exceptional_classes:
target_num_hidden_layers = None target_num_hidden_layers = None

View File

@ -118,27 +118,6 @@ from unittest.mock import patch
from transformers.utils import is_sklearn_available from transformers.utils import is_sklearn_available
# TODO: raushan remove this when VLMs start accepting input embeds
VLM_CLASS_NAMES = [
"llava",
"idefics2",
"idefics3",
"mllama",
"paligemma",
"emu3",
"gotocr2",
"qwen2vl",
"qwen2_5_vl",
"ayavision",
"janus",
"gemma3",
"mistral3",
"chameleon",
"internvl",
"qwen2_5omni", # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`
]
class GenerationTesterMixin: class GenerationTesterMixin:
input_name = "input_ids" input_name = "input_ids"
model_tester = None model_tester = None
@ -1228,7 +1207,23 @@ class GenerationTesterMixin:
"blip2", # overridden `generate()` "blip2", # overridden `generate()`
"instructblip", "instructblip",
"instructblipvideo", "instructblipvideo",
*VLM_CLASS_NAMES, # shouldn't suggest image tokens # All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
"llava",
"idefics2",
"idefics3",
"mllama",
"paligemma",
"emu3",
"gotocr2",
"qwen2vl",
"qwen2_5_vl",
"ayavision",
"janus",
"gemma3",
"mistral3",
"chameleon",
"internvl",
"qwen2_5omni", # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`,
] ]
): ):
self.skipTest(reason="May fix in the future: need model-specific fixes") self.skipTest(reason="May fix in the future: need model-specific fixes")
@ -1641,6 +1636,58 @@ class GenerationTesterMixin:
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
@pytest.mark.generate
def test_generate_from_random_inputs_embeds(self):
"""
Text-only: Tests that different `inputs_embeds` generate different outputs in models with `main_input=="input_ids"`.
Some models have 'images' as main input and thus can't generate with random text embeddings.
See `test_generate_from_inputs_embeds` for more general checks.
"""
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if config.is_encoder_decoder:
continue
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
continue
# No easy fix, let's skip the test for now
has_complex_embeds_computation = any(
model_name in model_class.__name__.lower() for model_name in ["moshi"]
)
if model_class.main_input_name != "input_ids" or has_complex_embeds_computation:
self.skipTest(
"The model's main input name in not `input_ids` and we need kwargs from input dict as well."
)
if hasattr(config, "scale_embedding"):
config.scale_embedding = False
generation_kwargs = {
"return_dict_in_generate": True,
"output_scores": True,
"do_sample": False,
"max_new_tokens": 5,
"min_new_tokens": 5, # generate exactly 5 tokens
}
input_ids = inputs_dict.pop("input_ids")
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds, **generation_kwargs)
# If we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids=input_ids, inputs_embeds=random_embeds, **generation_kwargs
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
@pytest.mark.generate @pytest.mark.generate
@parameterized.expand([("greedy", 1), ("beam search", 2)]) @parameterized.expand([("greedy", 1), ("beam search", 2)])
def test_generate_from_inputs_embeds(self, _, num_beams): def test_generate_from_inputs_embeds(self, _, num_beams):
@ -1662,34 +1709,22 @@ class GenerationTesterMixin:
continue continue
# There are a few exception patterns in this test: # There are a few exception patterns in this test:
# 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed # 1 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
requires_inputs_ids = any(model_name in model_class.__name__.lower() for model_name in ["idefics"])
# 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
# than calling the embedding layer with `input_ids`. Subcases of this exception: # than calling the embedding layer with `input_ids`. Subcases of this exception:
# 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag) # 1.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
if hasattr(config, "scale_embedding"): if hasattr(config, "scale_embedding"):
config.scale_embedding = False config.scale_embedding = False
# 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
pixel_values_is_mutually_exclusive = any(
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
)
if pixel_values_is_mutually_exclusive:
inputs_dict.pop("pixel_values", None)
inputs_dict.pop("pixel_values_videos", None)
inputs_dict.pop("pixel_values_images", None)
# HACK - in the case of granite speech, input_features and inputs_embeds are mutually exclusive; # HACK - in the case of granite speech, input_features and inputs_embeds are mutually exclusive;
# this is similar to VLMs and should likely be standardized for similar audio models in the future, # this is similar to VLMs and should likely be standardized for similar audio models in the future,
# then made generic here. # then made generic here.
if "granitespeech" in model_class.__name__.lower(): if "granitespeech" in model_class.__name__.lower():
inputs_dict.pop("input_features", None) inputs_dict.pop("input_features", None)
# 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds` # 1.B - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
has_complex_embeds_computation = any( has_complex_embeds_computation = any(
model_name in model_class.__name__.lower() for model_name in ["moshi"] model_name in model_class.__name__.lower() for model_name in ["moshi"]
) )
# 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate, # 2 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
# we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input. # we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
missing_attention_mask = "attention_mask" not in inputs_dict missing_attention_mask = "attention_mask" not in inputs_dict
@ -1702,31 +1737,23 @@ class GenerationTesterMixin:
"do_sample": False, "do_sample": False,
"max_new_tokens": 5, "max_new_tokens": 5,
"min_new_tokens": 5, # generate exactly 5 tokens "min_new_tokens": 5, # generate exactly 5 tokens
"use_cache": True,
} }
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict) outputs_from_ids = model.generate(input_ids=input_ids, **generation_kwargs, **inputs_dict)
self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5)) self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output). # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
# The output of the two calls should be the same. # The output of the two calls should be the same.
inputs_embeds = model.get_input_embeddings()(input_ids) inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate( outputs_from_embeds = model.generate(
input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
) )
if not has_complex_embeds_computation: if not has_complex_embeds_computation:
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds) self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
# If we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will # input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
# be the same # be the same
if not (requires_inputs_ids or missing_attention_mask): if not missing_attention_mask:
outputs_from_embeds_wo_ids = model.generate( outputs_from_embeds_wo_ids = model.generate(
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
) )
@ -1753,17 +1780,6 @@ class GenerationTesterMixin:
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
self.skipTest(reason="This model does not support `inputs_embeds` in generation") self.skipTest(reason="This model does not support `inputs_embeds` in generation")
# Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
pixel_values_is_mutually_exclusive = any(
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
)
if pixel_values_is_mutually_exclusive:
inputs_dict.pop("pixel_values", None)
inputs_dict.pop("pixel_values_videos", None)
inputs_dict.pop("pixel_values_images", None)
input_ids = inputs_dict.pop("input_ids") input_ids = inputs_dict.pop("input_ids")
model.config.use_cache = True model.config.use_cache = True
@ -1925,14 +1941,6 @@ class GenerationTesterMixin:
if "past_key_values" not in outputs: if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`") self.skipTest(reason="This model doesn't return `past_key_values`")
pixel_values_is_mutually_exclusive = any(
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
)
if pixel_values_is_mutually_exclusive:
inputs_dict.pop("pixel_values", None)
inputs_dict.pop("pixel_values_videos", None)
inputs_dict.pop("pixel_values_images", None)
input_ids = inputs_dict.pop("input_ids") input_ids = inputs_dict.pop("input_ids")
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
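Outside the test harness, the flow these rewritten tests exercise looks roughly like the sketch below: embed the prompt with the model's own input embedding layer and hand the embeddings to `generate`. Passing `input_ids` alongside `inputs_embeds` keeps the prompt in the returned sequences; omitting them returns only the newly generated tokens. The checkpoint name is just a small placeholder; any decoder-only causal LM whose `prepare_inputs_for_generation` accepts `inputs_embeds` should behave the same way.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-gpt2"  # placeholder tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

with torch.no_grad():
    # Prompt + 5 new tokens: `input_ids` are passed, so the prompt appears in the output.
    with_prompt = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, max_new_tokens=5)
    # Only the 5 new tokens: no `input_ids`, so just the continuation is returned.
    new_tokens_only = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5)

print(with_prompt.shape, new_tokens_only.shape)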
@ -297,7 +297,7 @@ class AltCLIPTextModelTester:
@require_torch @require_torch
class AltCLIPTextModelTest(ModelTesterMixin, unittest.TestCase): class AltCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (AltCLIPTextModel,) if is_torch_available() else () all_model_classes = (AltCLIPTextModel,) if is_torch_available() else ()
fx_compatible = True fx_compatible = False # Cannot be supported when `can_return_tuple` is used
test_pruning = False test_pruning = False
test_head_masking = False test_head_masking = False
@ -411,7 +411,7 @@ def prepare_img():
class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (AltCLIPModel,) if is_torch_available() else () all_model_classes = (AltCLIPModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": AltCLIPModel} if is_torch_available() else {} pipeline_model_mapping = {"feature-extraction": AltCLIPModel} if is_torch_available() else {}
fx_compatible = True fx_compatible = False # Cannot be supported when `can_return_tuple` is used
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
@ -189,49 +189,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
self.model_tester = AriaVisionText2TextModelTester(self) self.model_tester = AriaVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False) self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@unittest.skip( @unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
) )
@ -270,14 +227,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
def test_dola_decoding_sample(self): def test_dola_decoding_sample(self):
pass pass
@unittest.skip(reason="Unsupported")
def test_generate_from_inputs_embeds_0_greedy(self):
pass
@unittest.skip(reason="Unsupported")
def test_generate_from_inputs_embeds_1_beam_search(self):
pass
@unittest.skip(reason="Dynamic control flow due to MoE") @unittest.skip(reason="Dynamic control flow due to MoE")
def test_generate_with_static_cache(self): def test_generate_with_static_cache(self):

@ -62,7 +62,7 @@ class AyaVisionVisionText2TextModelTester:
bos_token_id=0, bos_token_id=0,
eos_token_id=0, eos_token_id=0,
pad_token_id=0, pad_token_id=0,
image_token_index=1, image_token_index=2,
num_channels=3, num_channels=3,
image_size=64, image_size=64,
model_type="aya_vision", model_type="aya_vision",
@ -183,49 +183,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@unittest.skip("Failing because of unique cache (HybridCache)") @unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs): def test_model_outputs_equivalence(self, **kwargs):
pass pass
@ -285,10 +242,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_generate_from_inputs_embeds_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self):
pass pass
@unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation using input embeds.")
def test_generate_continue_from_inputs_embeds(self):
pass
@unittest.skip("Failing because of unique cache (HybridCache)") @unittest.skip("Failing because of unique cache (HybridCache)")
def test_multi_gpu_data_parallel_forward(self): def test_multi_gpu_data_parallel_forward(self):
pass pass
@ -20,7 +20,6 @@ import unittest
import numpy as np import numpy as np
import pytest import pytest
import requests import requests
from parameterized import parameterized
from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
from transformers.testing_utils import ( from transformers.testing_utils import (
@ -674,15 +673,6 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
# They should result in very similar logits # They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
@unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present")
@parameterized.expand([("greedy", 1), ("beam search", 2)])
def test_generate_from_inputs_embeds(self, _, num_beams):
pass
@unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
# this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py # this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py
class Blip2TextModelTester: class Blip2TextModelTester:

Some files were not shown because too many files have changed in this diff.