diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bb15454c7f5..de6ae44bb5a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -733,7 +733,9 @@ class GenerationMixin(ContinuousMixin): # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states. if input_name == "input_ids" and "inputs_embeds" in model_kwargs: - if not self.config.is_encoder_decoder: + if model_kwargs["inputs_embeds"] is None: + model_kwargs.pop("inputs_embeds") + elif not self.config.is_encoder_decoder: has_inputs_embeds_forwarding = "inputs_embeds" in set( inspect.signature(self.prepare_inputs_for_generation).parameters.keys() ) @@ -748,10 +750,11 @@ class GenerationMixin(ContinuousMixin): model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( inputs, bos_token_id, model_kwargs=model_kwargs ) + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" else: if inputs is not None: raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") - inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" # 4. if `inputs` is still None, try to create `input_ids` from BOS token inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 0ffce54977f..61c20ecf6db 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +import math import operator import os import re @@ -280,7 +281,48 @@ def repack_weights( def get_tensor_shard(param, empty_param, device_mesh, rank, dim): """ Generalized tensor sharding across a multi-dimensional device mesh. + Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`. + Extraction follows the PyTorch `Shard` placement, so that sharding and materializing back to the full tensor follow `Shard` semantics. + `Shard` follows torch.chunk-style sharding of the tensor. The cases below demonstrate how sharding happens, including edge cases + such as some ranks ending up with an empty tensor as their shard. The implementation below is robust to all of these cases. + Case (1) + empty_param (16, 5120, 8190) + dim 0 + device_mesh.size() 4 + rank 0 gets (4, 5120, 8190) (0 ... 4, 5120, 8190) + rank 1 gets (4, 5120, 8190) (4 ... 8, 5120, 8190) + rank 2 gets (4, 5120, 8190) (8 ... 12, 5120, 8190) + rank 3 gets (4, 5120, 8190) (12 ... 16, 5120, 8190) + + Case (2) + empty_param (16, 5120, 8190) + dim 0 + device_mesh.size() 14 + rank 0 gets (2, 5120, 8190) (0 ... 2, 5120, 8190) + rank 1 gets (2, 5120, 8190) (2 ... 4, 5120, 8190) + rank 2 gets (2, 5120, 8190) (4 ... 6, 5120, 8190) + rank 3 gets (2, 5120, 8190) (6 ... 8, 5120, 8190) + rank 4 gets (2, 5120, 8190) (8 ... 10, 5120, 8190) + rank 5 gets (2, 5120, 8190) (10 ... 12, 5120, 8190) + rank 6 gets (2, 5120, 8190) (12 ... 14, 5120, 8190) + rank 7 gets (2, 5120, 8190) (14 ...
16, 5120, 8190) + rank 8 gets (0, 5120, 8190) + rank 9 gets (0, 5120, 8190) + rank 10 gets (0, 5120, 8190) + rank 11 gets (0, 5120, 8190) + rank 12 gets (0, 5120, 8190) + rank 13 gets (0, 5120, 8190) + + Case (3) + empty_param (16, 5120, 8190) + dim 0 + device_mesh.size() 3 + rank 0 gets (6, 5120, 8190) (0 ... 6, 5120, 8190) + rank 1 gets (6, 5120, 8190) (6 ... 12, 5120, 8190) + rank 2 gets (4, 5120, 8190) (12 ... 16, 5120, 8190) + + In case (2), the empty shards are returned with the appropriate dimensions so that downstream operations keep working smoothly. Args: param (torch.Tensor): The tensor to shard. empty_param (torch.Tensor): A tensor used for shape reference. @@ -289,6 +331,7 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim): dim (int): Dimension along which to shard the tensor. """ param_dim = empty_param.dim() + if dim < 0: dim = param_dim + dim if dim >= param_dim: @@ -301,15 +344,18 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim): if rank >= world_size: raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}") - shard_size = empty_param.shape[dim] // world_size + shard_size = math.ceil(empty_param.shape[dim] / world_size) start = rank * shard_size - end = start + shard_size # Construct slicing index dynamically + end = min(start + shard_size, empty_param.shape[dim]) slice_indices = [slice(None)] * param_dim - slice_indices[dim] = slice(start, end) - - return param[tuple(slice_indices)] + if start < empty_param.shape[dim]: + slice_indices[dim] = slice(start, end) + return param[tuple(slice_indices)] + dimensions = list(param.shape) + dimensions[dim] = 0 + return torch.empty(tuple(dimensions), dtype=torch.int64) def distribute_module( @@ -500,7 +546,9 @@ class ColwiseParallel(TensorParallelLayer): if to_contiguous: parameter = parameter.contiguous() if self.use_dtensor: - parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False) + parameter = DTensor.from_local( + parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride() + ) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) @staticmethod @@ -574,7 +622,9 @@ class RowwiseParallel(TensorParallelLayer): if to_contiguous: parameter = parameter.contiguous() if self.use_dtensor: - parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False) + parameter = DTensor.from_local( + parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride() + ) return nn.Parameter(parameter, requires_grad=parameter.is_floating_point()) @staticmethod
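As a quick illustration of the torch.chunk-style rule documented in the `get_tensor_shard` docstring above, here is a minimal, self-contained sketch (not part of the patch; the sizes mirror case (2)) showing the ceil-division shard size, the clamped end index, and the empty trailing shards:

```python
import math

def shard_bounds(dim_size: int, world_size: int, rank: int) -> tuple[int, int]:
    # torch.chunk-style sharding: ceil-division shard size, end clamped to the dimension size
    shard_size = math.ceil(dim_size / world_size)
    start = min(rank * shard_size, dim_size)
    end = min(start + shard_size, dim_size)
    return start, end

# Case (2) above: a dimension of size 16 split across 14 ranks
for rank in range(14):
    start, end = shard_bounds(16, 14, rank)
    print(f"rank {rank} gets rows {start}..{end} (size {end - start})")
# ranks 0-7 each own 2 rows, ranks 8-13 own an empty shard of size 0
```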
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 649447ca8f7..13da327dab0 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -508,6 +508,22 @@ def _flash_attention_forward( query_states, key_states, value_states, target_dtype ) + # We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach + # in the following two cases: + # Case 1. `position_ids` is provided and not all examples consist of a single sequence: if the tensor is monotonically increasing + # then we probably have one sequence, otherwise it is packed. We additionally check that we are in the pre-fill/training stage. + # Case 2. Some models directly pass pre-computed `cu_seqlens`, so we don't need to infer them from position ids. It is safe to + # use `flash_attn_varlen_func` knowing we already have all the necessary kwargs. NOTE: it is the user's responsibility + # to take care of flattening `position_ids` if that's needed by the model. See #39121 for more information + is_fa2_with_position_ids = ( + position_ids is not None + and query_states.shape[0] == 1 + and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())) + ) + is_fa2_with_varlen_kwargs = all( + kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) + ) + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -531,14 +547,7 @@ def _flash_attention_forward( ) attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length) - # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing - # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. - # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - elif ( - position_ids is not None - and query_states.shape[0] == 1 - and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())) - ): + elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids: batch_size = query_states.size(0) if cu_seq_lens_q is None or cu_seq_lens_k is None: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 515fb6d3811..e6b7031ab37 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3746,7 +3746,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi module_map[name + f".{key}"] = module state_dict = model_to_save.state_dict() - if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS): + if any( + allowed_name in class_name.__name__.lower() + for class_name in self.__class__.__mro__[:-1] + for allowed_name in VLMS + ): reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()} original_state_dict = {} @@ -4402,7 +4406,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi key_mapping = kwargs.pop("key_mapping", None) # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model - if key_mapping is None and any(allowed_name in cls.__name__.lower() for allowed_name in VLMS): + if key_mapping is None and any( + allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS + ): key_mapping = cls._checkpoint_conversion_mapping # Not used anymore -- remove them from the kwargs @@ -5837,7 +5843,12 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, else None ) total_byte_count = defaultdict(lambda: 0) + tied_param_names = _get_tied_weight_keys(model) for param_name, device in accelerator_device_map.items(): + # Skip if the parameter has already been accounted for (tied weights) + if param_name in tied_param_names: + continue + param = model.get_parameter_or_buffer(param_name) # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules` param_byte_count = param.numel() * param.element_size() diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index f608eab3de3..da015bf7dd2 100644 ---
a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.utils.checkpoint @@ -25,14 +25,15 @@ from torch import nn from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( + BaseModelOutput, BaseModelOutputWithNoAttention, - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, + BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndNoAttention, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ModelOutput, auto_docstring, logging +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig @@ -90,7 +91,7 @@ class AlignOutput(ModelOutput): The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`]. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): The output of [`AlignVisionModel`]. - text_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`): + text_model_output (`BaseModelOutputWithPooling`): The output of the [`AlignTextModel`]. vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`): The output of the [`AlignVisionModel`]. @@ -101,7 +102,7 @@ class AlignOutput(ModelOutput): logits_per_text: Optional[torch.FloatTensor] = None text_embeds: Optional[torch.FloatTensor] = None image_embeds: Optional[torch.FloatTensor] = None - text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None + text_model_output: BaseModelOutputWithPooling = None vision_model_output: BaseModelOutputWithPoolingAndNoAttention = None def to_tuple(self) -> tuple[Any]: @@ -508,7 +509,6 @@ class AlignVisionEncoder(nn.Module): ) -# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->AlignText class AlignTextEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -537,7 +537,6 @@ class AlignTextEmbeddings(nn.Module): token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values_length: int = 0, ) -> torch.Tensor: if input_ids is not None: input_shape = input_ids.size() @@ -547,7 +546,7 @@ class AlignTextEmbeddings(nn.Module): seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + position_ids = self.position_ids[:, :seq_length] # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves @@ -573,9 +572,35 @@ class AlignTextEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->AlignText +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + 
scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class AlignTextSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -583,6 +608,7 @@ class AlignTextSelfAttention(nn.Module): f"heads ({config.num_attention_heads})" ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -592,20 +618,12 @@ class AlignTextSelfAttention(nn.Module): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - - self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) + self.attention_dropout = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -615,96 +633,33 @@ class AlignTextSelfAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. 
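For reference, the removed `transpose_for_scores` helper and the `view(...).transpose(1, 2)` pattern that replaces it in the new forward produce the same head-split layout; a small sketch with made-up sizes:

```python
import torch

batch, seq, num_heads, head_dim = 2, 5, 8, 64
hidden = torch.randn(batch, seq, num_heads * head_dim)

# old helper: view to (batch, seq, heads, head_dim), then permute to (batch, heads, seq, head_dim)
old_layout = hidden.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)
# new pattern: same reshape, expressed with a single transpose
new_layout = hidden.view(batch, seq, -1, head_dim).transpose(1, 2)

assert torch.equal(old_layout, new_layout)  # identical (batch, heads, seq, head_dim) layout
```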
- is_cross_attention = encoder_hidden_states is not None + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_layer = self.transpose_for_scores(mixed_query_layer) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. 
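The `attention_interface` selection just introduced mirrors the library-wide dispatch pattern: eager is the default, and any other implementation is looked up by the `config._attn_implementation` key. A simplified sketch of that selection logic (the registry contents here are illustrative stand-ins, not the real `ALL_ATTENTION_FUNCTIONS`):

```python
from typing import Callable

def eager_attention(*args, **kwargs):
    ...  # stand-in for the eager_attention_forward defined above

# illustrative registry; the real ALL_ATTENTION_FUNCTIONS maps implementation names to kernels
ATTENTION_REGISTRY: dict[str, Callable] = {
    "sdpa": lambda *args, **kwargs: ...,
    "flash_attention_2": lambda *args, **kwargs: ...,
}

def select_attention(attn_implementation: str) -> Callable:
    # eager is the default; other implementations are looked up by name
    if attn_implementation == "eager":
        return eager_attention
    return ATTENTION_REGISTRY[attn_implementation]
```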
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in AlignTextModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
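A subtle detail of the rewrite: the removed code above divided the raw scores by `sqrt(head_dim)` after the matmul, whereas `eager_attention_forward` multiplies by a pre-computed `scaling = head_dim ** -0.5`. The two are numerically equivalent; a quick check with made-up sizes:

```python
import math
import torch

head_dim = 64
q = torch.randn(2, 8, 5, head_dim)  # (batch, heads, seq, head_dim)
k = torch.randn(2, 8, 5, head_dim)

old_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
new_scores = torch.matmul(q, k.transpose(2, 3)) * head_dim**-0.5

assert torch.allclose(old_scores, new_scores)
```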
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs @@ -723,18 +678,10 @@ class AlignTextSelfOutput(nn.Module): return hidden_states -ALIGN_TEXT_SELF_ATTENTION_CLASSES = { - "eager": AlignTextSelfAttention, -} - - -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT class AlignTextAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type - ) + self.self = AlignTextSelfAttention(config) self.output = AlignTextSelfOutput(config) self.pruned_heads = set() @@ -756,6 +703,9 @@ class AlignTextAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -765,15 +715,14 @@ class AlignTextAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -811,22 +760,18 @@ class AlignTextOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText class AlignTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = AlignTextAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = AlignTextAttention(config, position_embedding_type="absolute") self.intermediate = AlignTextIntermediate(config) self.output = AlignTextOutput(config) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ 
-836,60 +781,23 @@ class AlignTextLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + **kwargs, ) attention_output = self_attention_outputs[0] - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -898,14 +806,18 @@ class AlignTextLayer(GradientCheckpointingLayer): return layer_output -# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->AlignText class AlignTextEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([AlignTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([AlignTextLayer(config) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple def forward( self, hidden_states: torch.Tensor, @@ -918,65 +830,36 @@ class AlignTextEncoder(nn.Module): output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, - ) -> Union[tuple[torch.Tensor], 
BaseModelOutputWithPastAndCrossAttentions]: + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, # as a positional argument for gradient checkpointing - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, output_attentions=output_attentions, + **kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutput( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, - cross_attentions=all_cross_attentions, ) @@ -1052,6 +935,7 @@ class AlignTextModel(AlignPreTrainedModel): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value + @can_return_tuple @auto_docstring def forward( self, @@ -1059,12 +943,13 @@ class AlignTextModel(AlignPreTrainedModel): attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]: + **kwargs, + ) -> Union[tuple, BaseModelOutputWithPooling]: r""" Examples: @@ -1133,20 +1018,17 @@ class AlignTextModel(AlignPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, + **kwargs, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( + return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - 
cross_attentions=encoder_outputs.cross_attentions, ) @@ -1180,6 +1062,7 @@ class AlignVisionModel(AlignPreTrainedModel): def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.convolution + @can_return_tuple @auto_docstring def forward( self, @@ -1219,7 +1102,7 @@ class AlignVisionModel(AlignPreTrainedModel): encoder_outputs = self.encoder( embedding_output, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) # Apply pooling last_hidden_state = encoder_outputs[0] @@ -1227,9 +1110,6 @@ class AlignVisionModel(AlignPreTrainedModel): # Reshape (batch_size, projection_dim, 1 , 1) -> (batch_size, projection_dim) pooled_output = pooled_output.reshape(pooled_output.shape[:2]) - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPoolingAndNoAttention( last_hidden_state=last_hidden_state, pooler_output=pooled_output, @@ -1369,6 +1249,7 @@ class AlignModel(AlignPreTrainedModel): return image_features + @can_return_tuple @auto_docstring def forward( self, @@ -1419,7 +1300,7 @@ class AlignModel(AlignPreTrainedModel): vision_outputs = self.vision_model( pixel_values=pixel_values, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) text_outputs = self.text_model( @@ -1431,7 +1312,7 @@ class AlignModel(AlignPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) image_embeds = vision_outputs[1] @@ -1450,10 +1331,6 @@ class AlignModel(AlignPreTrainedModel): if return_loss: loss = align_loss(logits_per_text) - if not return_dict: - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output - return AlignOutput( loss=loss, logits_per_image=logits_per_image, diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 8f6f0ff7fbc..c770dd5adce 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -26,14 +26,14 @@ from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPoolingAndProjection, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ModelOutput, auto_docstring, logging, torch_int +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int +from ...utils.deprecation import deprecate_kwarg from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig @@ -180,7 +180,6 @@ class AltRobertaEmbeddings(nn.Module): return position_ids.unsqueeze(0).expand(input_shape) -# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->AltRoberta class AltRobertaSelfAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() @@ -206,13 +205,9 @@ class AltRobertaSelfAttention(nn.Module): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * 
config.max_position_embeddings - 1, self.attention_head_size) - self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -223,55 +218,19 @@ class AltRobertaSelfAttention(nn.Module): past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. 
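The relative-position branch retained above keeps one learned embedding per possible query/key offset, which is why `distance_embedding` has `2 * max_position_embeddings - 1` rows. A small sketch (illustrative lengths) of how the signed offsets map onto those rows:

```python
import torch

max_position_embeddings = 512
query_length, key_length = 4, 4

position_ids_l = torch.arange(query_length).view(-1, 1)
position_ids_r = torch.arange(key_length).view(1, -1)
distance = position_ids_l - position_ids_r         # offsets in [-(key_length - 1), query_length - 1]
indices = distance + max_position_embeddings - 1   # shifted into [0, 2 * max_position_embeddings - 2]
```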
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r @@ -310,8 +269,6 @@ class AltRobertaSelfAttention(nn.Module): outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - if self.is_decoder: - outputs = outputs + (past_key_value,) return outputs @@ -335,7 +292,6 @@ ALT_ROBERTA_SELF_ATTENTION_CLASSES = { } -# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA class AltRobertaAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() @@ -363,6 +319,9 @@ class AltRobertaAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -375,12 +334,9 @@ class AltRobertaAttention(nn.Module): ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -418,22 +374,19 @@ class AltRobertaOutput(nn.Module): return hidden_states -# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta +# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->AltRoberta class AltRobertaLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = AltRobertaAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = AltRobertaAttention(config, position_embedding_type="absolute") self.intermediate = AltRobertaIntermediate(config) self.output = AltRobertaOutput(config) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -443,60 +396,23 @@ class AltRobertaLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: 
Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + **kwargs, ) attention_output = self_attention_outputs[0] - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -505,14 +421,19 @@ class AltRobertaLayer(GradientCheckpointingLayer): return layer_output -# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->AltRoberta +# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->AltRoberta class AltRobertaEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([AltRobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([AltRobertaLayer(config) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple def forward( self, hidden_states: torch.Tensor, @@ -525,65 +446,36 @@ class AltRobertaEncoder(nn.Module): output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, - ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: 
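Both encoders now rely on `@can_return_tuple` instead of hand-rolled `if not return_dict:` branches. A rough sketch of the behavior such a decorator is assumed to provide (not the actual transformers implementation): the forward always builds a `ModelOutput`, and the wrapper converts it to a plain tuple when the caller asks for one.

```python
import functools

def can_return_tuple_sketch(forward):
    # assumed behavior: convert the ModelOutput to a plain tuple when return_dict=False is requested
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        output = forward(self, *args, **kwargs)
        use_return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
        return output if use_return_dict else output.to_tuple()
    return wrapper
```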
all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, # as a positional argument for gradient checkpointing - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, output_attentions=output_attentions, + **kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutput( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, - cross_attentions=all_cross_attentions, ) @@ -787,6 +679,7 @@ class AltCLIPEncoder(nn.Module): self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @can_return_tuple def forward( self, inputs_embeds, @@ -853,8 +746,6 @@ class AltCLIPEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) @@ -1008,6 +899,7 @@ class AltCLIPVisionTransformer(nn.Module): self.encoder = AltCLIPEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + @can_return_tuple @auto_docstring def forward( self, @@ -1033,16 +925,13 @@ class AltCLIPVisionTransformer(nn.Module): inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) last_hidden_state = encoder_outputs[0] pooled_output = last_hidden_state[:, 0, :] pooled_output = self.post_layernorm(pooled_output) - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, @@ -1106,16 +995,11 @@ class AltCLIPVisionModel(AltCLIPPreTrainedModel): @auto_docstring( custom_intro=""" - The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - 
cross-attention is added between the self-attention layers, following the architecture described in *Attention is + The model behaves as an encoder following the architecture described in *Attention is all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set - to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and - `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. - - .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762 + .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762 """ ) class AltRobertaModel(AltCLIPPreTrainedModel): @@ -1152,6 +1036,10 @@ class AltRobertaModel(AltCLIPPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") @auto_docstring # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( @@ -1176,11 +1064,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -1194,11 +1077,8 @@ class AltRobertaModel(AltCLIPPreTrainedModel): batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + attention_mask = torch.ones(((batch_size, seq_length)), device=device) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): @@ -1212,21 +1092,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel): # ourselves in which case we just need to make it broadcastable to all heads. 
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) @@ -1235,33 +1100,23 @@ class AltRobertaModel(AltCLIPPreTrainedModel): position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( + return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, ) @@ -1284,6 +1139,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel): def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding: return super().resize_token_embeddings(new_num_tokens) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @can_return_tuple @auto_docstring def forward( self, @@ -1326,11 +1184,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) # last module outputs @@ -1343,9 +1199,6 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel): projection_state = self.transformation(sequence_output) pooler_output = projection_state[:, 0] - if not return_dict: - return (projection_state, pooler_output) + outputs[2:4] - return BaseModelOutputWithPoolingAndProjection( last_hidden_state=projection_state, pooler_output=pooler_output, diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index 
c224c4300eb..da233918123 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -225,7 +225,7 @@ class ArceeAttention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 87f11d19269..00b912c5b27 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -532,7 +532,7 @@ class AriaTextAttention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -1113,11 +1113,12 @@ class AriaModel(AriaPreTrainedModel): special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + special_image_mask = special_image_mask.all(-1) else: - image_embeds = input_ids == self.config.image_token_id - special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0) + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_features = self.get_image_features( pixel_values=pixel_values, pixel_mask=pixel_mask, diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index b5c18a40b71..a40041a82bb 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1446,11 +1446,12 @@ class AriaModel(LlavaModel): special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + special_image_mask = special_image_mask.all(-1) else: - image_embeds = input_ids == self.config.image_token_id - special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0) + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_features = self.get_image_features( pixel_values=pixel_values, pixel_mask=pixel_mask, diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index a55986b1631..4983a0dcca7 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -302,14 +302,14 @@ class 
AyaVisionModel(AyaVisionPreTrainedModel): special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + special_image_mask = special_image_mask.all(-1) else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_tokens = (input_ids == self.config.image_token_id).sum() + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0) + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index ad5c1e58d43..93e7e3184a1 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -223,14 +223,14 @@ class AyaVisionModel(LlavaModel): special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + special_image_mask = special_image_mask.all(-1) else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_tokens = (input_ids == self.config.image_token_id).sum() + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0) + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 12c6e52c65b..698a74449ba 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -296,7 +296,7 @@ class BambaAttention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 8d26aee4e56..c81b990b7de 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ 
-1855,6 +1855,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): _supports_cache_class = True _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) + _keep_in_fp32_modules = ["query_tokens", "qformer"] def __init__(self, config: Blip2Config): @@ -1971,10 +1972,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): def forward( self, pixel_values: torch.FloatTensor, - input_ids: torch.FloatTensor, + input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, @@ -2066,14 +2068,25 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): language_model_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if attention_mask is None: attention_mask = torch.ones_like(input_ids) # if the model already has "image_token_id" then the input is expanded to account for image embeds - # otherwise we expand manually by concating + # otherwise we expand manually by concatenating if getattr(self.config, "image_token_id", None) is not None: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) else: @@ -2146,6 +2159,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): pixel_values: torch.FloatTensor, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: @@ -2159,6 +2173,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): The sequence used as a prompt for the generation. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): Mask to avoid performing attention on padding token indices + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the positional encoding of the image embeddings. Returns: captions (list): A list of strings of length batch_size * num_captions. 
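
Note (not part of the patch): the BLIP-2 hunk above and the Aria/AyaVision hunks earlier in this diff (and the Chameleon hunk below) all converge on the same placeholder-handling pattern: build `special_image_mask` from `input_ids` when token ids are available, otherwise recover it from `inputs_embeds` by comparing against the embedded image token, then merge the vision features with `masked_scatter`. The following is a minimal, self-contained sketch of that pattern; `scatter_image_features`, `embed_tokens`, `image_token_id` and the toy shapes are illustrative assumptions, not the models' real modules.

# Illustrative sketch only -- not part of the diff.
from typing import Optional

import torch
from torch import nn


def scatter_image_features(
    input_ids: Optional[torch.LongTensor],
    inputs_embeds: torch.FloatTensor,
    image_features: torch.FloatTensor,
    embed_tokens: nn.Embedding,
    image_token_id: int,
) -> torch.FloatTensor:
    if input_ids is None:
        # Embeddings-only call: recover placeholder positions by comparing each row
        # of `inputs_embeds` against the embedding of the image token.
        image_token_embed = embed_tokens(
            torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
        )
        special_image_mask = (inputs_embeds == image_token_embed).all(-1)
    else:
        # Token-ids call: placeholder positions are simply the image-token ids.
        special_image_mask = input_ids == image_token_id

    n_image_tokens = special_image_mask.sum()
    special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)

    if inputs_embeds[special_image_mask].numel() != image_features.numel():
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, "
            f"features {image_features.shape[0] * image_features.shape[1]}"
        )
    # Write the image features into the placeholder slots without in-place indexing,
    # which keeps the op friendly to torchdynamo/torch.compile tracing.
    return inputs_embeds.masked_scatter(special_image_mask, image_features.to(inputs_embeds.dtype))


# Toy usage: batch of 1, 6 positions of which 2 are image placeholders (token id 7).
embed = nn.Embedding(10, 4)
ids = torch.tensor([[1, 7, 7, 2, 3, 4]])
embeds = embed(ids)
feats = torch.randn(1, 2, 4)  # 2 image "patches", hidden size 4
merged = scatter_image_features(ids, embeds, feats, embed, image_token_id=7)
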
@@ -2193,22 +2211,32 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) - if input_ids is None: - start_tokens = [self.config.text_config.bos_token_id] - if getattr(self.config, "image_token_id", None) is not None: - start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens - input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) - input_ids = input_ids.repeat(batch_size, 1) + if inputs_embeds is None: + if input_ids is None: + start_tokens = [self.config.text_config.bos_token_id] + if getattr(self.config, "image_token_id", None) is not None: + start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) + input_ids = input_ids.repeat(batch_size, 1) + inputs_embeds = self.get_input_embeddings()(input_ids) - inputs_embeds = self.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) # if the model already has "image_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating if getattr(self.config, "image_token_id", None) is not None: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) else: logger.warning_once( "Expanding inputs for image tokens in BLIP-2 should be done in processing. 
" diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 9b4225655f5..1635c7b3d45 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -1026,7 +1026,6 @@ class BridgeTowerTextModel(BridgeTowerPreTrainedModel): self.encoder.layer[layer].attention.prune_heads(heads) @auto_docstring - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index e65f0166d7f..dd8ecf2c9eb 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -26,13 +26,14 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ModelOutput, auto_docstring, logging +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_bros import BrosConfig @@ -150,7 +151,6 @@ class BrosTextEmbeddings(nn.Module): token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, - past_key_values_length: int = 0, ) -> torch.Tensor: if input_ids is not None: input_shape = input_ids.size() @@ -160,7 +160,7 @@ class BrosTextEmbeddings(nn.Module): seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + position_ids = self.position_ids[:, :seq_length] if token_type_ids is None: if hasattr(self, "token_type_ids"): @@ -208,14 +208,7 @@ class BrosSelfAttention(nn.Module): self.is_decoder = config.is_decoder - def transpose_for_scores(self, x: torch.Tensor): - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -227,42 +220,21 @@ class BrosSelfAttention(nn.Module): past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[torch.Tensor] = False, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size) + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] + if is_cross_attention: + key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2) attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -317,7 +289,7 @@ class BrosSelfAttention(nn.Module): outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: - outputs = outputs + (past_key_value,) + outputs = outputs + (None,) return outputs @@ -364,6 +336,7 @@ class BrosAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -382,7 +355,6 @@ class BrosAttention(nn.Module): head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, output_attentions=output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) @@ -435,6 +407,7 @@ class BrosLayer(GradientCheckpointingLayer): self.intermediate = BrosIntermediate(config) self.output = BrosOutput(config) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -446,50 +419,38 @@ class BrosLayer(GradientCheckpointingLayer): past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, bbox_pos_emb=bbox_pos_emb, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if hasattr(self, "crossattention"): raise Exception( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, @@ -500,7 +461,7 @@ class BrosLayer(GradientCheckpointingLayer): # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + 
(present_key_value,) + outputs = outputs + (None,) return outputs @@ -516,6 +477,9 @@ class BrosEncoder(nn.Module): self.config = config self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)]) + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple def forward( self, hidden_states: torch.Tensor, @@ -529,33 +493,28 @@ class BrosEncoder(nn.Module): output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, - ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( - hidden_states, - bbox_pos_emb, - attention_mask, - layer_head_mask, - encoder_hidden_states, # as a positional argument for gradient checkpointing + hidden_states=hidden_states, + bbox_pos_emb=bbox_pos_emb, + attention_mask=attention_mask, + head_mask=layer_head_mask, + encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -564,21 +523,8 @@ class BrosEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutputWithCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -689,6 +635,9 @@ class BrosModel(BrosPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple @auto_docstring def forward( self, @@ -736,11 +685,6 @@ class BrosModel(BrosPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -756,9 +700,6 @@ class BrosModel(BrosPreTrainedModel): batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = 
past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: attention_mask = torch.ones(input_shape, device=device) @@ -797,7 +738,6 @@ class BrosModel(BrosPreTrainedModel): position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, ) # if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token @@ -813,22 +753,16 @@ class BrosModel(BrosPreTrainedModel): head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, @@ -852,6 +786,7 @@ class BrosForTokenClassification(BrosPreTrainedModel): self.init_weights() + @can_return_tuple @auto_docstring def forward( self, @@ -908,7 +843,7 @@ class BrosForTokenClassification(BrosPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = outputs[0] @@ -927,10 +862,6 @@ class BrosForTokenClassification(BrosPreTrainedModel): else: loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -976,6 +907,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel): self.init_weights() + @can_return_tuple @auto_docstring def forward( self, @@ -1037,7 +969,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) last_hidden_states = outputs[0] @@ -1082,10 +1014,6 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel): loss = initial_token_loss + subsequent_token_loss - if not return_dict: - output = (initial_token_logits, subsequent_token_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return BrosSpadeOutput( loss=loss, initial_token_logits=initial_token_logits, @@ -1118,6 +1046,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel): self.init_weights() + @can_return_tuple @auto_docstring def forward( self, @@ -1173,7 +1102,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) last_hidden_states = outputs[0] @@ -1203,10 +1132,6 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel): loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask]) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) 
if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 0aaf197dbe8..010fe244de9 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -963,25 +963,28 @@ class ChameleonModel(ChameleonPreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) if pixel_values is not None: - image_tokens = self.get_image_tokens(pixel_values) - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id - if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel(): - n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum() - n_image_features = image_tokens.shape[0] * image_tokens.shape[1] + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + + n_image_tokens_in_text = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + + image_embeds = self.get_image_features(pixel_values) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_embeds.numel(): + n_image_features = image_embeds.shape[0] * image_embeds.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}" ) - image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) - input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds) # torch.jit.trace() doesn't support cache objects in the output if use_cache and past_key_values is None and not torch.jit.is_tracing(): diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 2ea9225b552..6ab3ade7c25 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -14,9 +14,8 @@ # limitations under the License. 
"""PyTorch Chinese-CLIP model.""" -import math from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.utils.checkpoint @@ -26,13 +25,13 @@ from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import ModelOutput, auto_docstring, logging, torch_int +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int +from ...utils.deprecation import deprecate_kwarg from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig @@ -90,7 +89,7 @@ class ChineseCLIPOutput(ModelOutput): ) -# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->ChineseCLIPText +# Copied from transformers.models.align.modeling_align.AlignTextEmbeddings with Align->ChineseCLIP class ChineseCLIPTextEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -119,7 +118,6 @@ class ChineseCLIPTextEmbeddings(nn.Module): token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values_length: int = 0, ) -> torch.Tensor: if input_ids is not None: input_shape = input_ids.size() @@ -129,7 +127,7 @@ class ChineseCLIPTextEmbeddings(nn.Module): seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + position_ids = self.position_ids[:, :seq_length] # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves @@ -239,9 +237,37 @@ class ChineseCLIPVisionEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ChineseCLIPText +# Copied from transformers.models.align.modeling_align.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->ChineseCLIP class ChineseCLIPTextSelfAttention(nn.Module): - 
def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -249,6 +275,7 @@ class ChineseCLIPTextSelfAttention(nn.Module): f"heads ({config.num_attention_heads})" ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -258,20 +285,12 @@ class ChineseCLIPTextSelfAttention(nn.Module): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - - self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) + self.attention_dropout = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -281,96 +300,33 @@ class ChineseCLIPTextSelfAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. 
- is_cross_attention = encoder_hidden_states is not None + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_layer = self.transpose_for_scores(mixed_query_layer) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. 
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in ChineseCLIPTextModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs @@ -389,18 +345,11 @@ class ChineseCLIPTextSelfOutput(nn.Module): return hidden_states -CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = { - "eager": ChineseCLIPTextSelfAttention, -} - - -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT +# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->ChineseCLIP class ChineseCLIPTextAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type - ) + self.self = ChineseCLIPTextSelfAttention(config) self.output = ChineseCLIPTextSelfOutput(config) self.pruned_heads = set() @@ -422,6 +371,9 @@ class ChineseCLIPTextAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -431,15 +383,14 @@ class ChineseCLIPTextAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -468,66 +419,37 @@ class ChineseCLIPVisionAttention(nn.Module): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( - self, - hidden_states: torch.Tensor, - output_attentions: Optional[bool] = False, + self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, **kwargs ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, embed_dim = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - # get query proj - query_states = self.q_proj(hidden_states) * 
self.scale - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.scale + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit akward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + None, + dropout=0.0 if not self.training else self.dropout, + scaling=1.0, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped + return attn_output, attn_weights # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText @@ -577,22 +499,19 @@ class ChineseCLIPVisionMLP(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText +# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->ChineseCLIP class ChineseCLIPTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = ChineseCLIPTextAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = ChineseCLIPTextAttention(config, position_embedding_type="absolute") self.intermediate = 
ChineseCLIPTextIntermediate(config) self.output = ChineseCLIPTextOutput(config) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -602,60 +521,23 @@ class ChineseCLIPTextLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + **kwargs, ) attention_output = self_attention_outputs[0] - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -777,14 +659,19 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel): module.bias.data.zero_() -# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ChineseCLIPText +# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP class ChineseCLIPTextEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + 
@deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple def forward( self, hidden_states: torch.Tensor, @@ -797,65 +684,36 @@ class ChineseCLIPTextEncoder(nn.Module): output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, - ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, # as a positional argument for gradient checkpointing - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, output_attentions=output_attentions, + **kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutput( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, - cross_attentions=all_cross_attentions, ) @@ -874,6 +732,7 @@ class ChineseCLIPVisionEncoder(nn.Module): self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @can_return_tuple def forward( self, inputs_embeds, @@ -922,8 +781,6 @@ class ChineseCLIPVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) @@ -940,6 +797,7 @@ class ChineseCLIPVisionTransformer(nn.Module): self.encoder = ChineseCLIPVisionEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + @can_return_tuple @auto_docstring def forward( self, @@ -965,16 +823,13 @@ class ChineseCLIPVisionTransformer(nn.Module): inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, 
- return_dict=return_dict, + return_dict=True, ) last_hidden_state = encoder_outputs[0] pooled_output = last_hidden_state[:, 0, :] pooled_output = self.post_layernorm(pooled_output) - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, @@ -1034,6 +889,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) + @can_return_tuple @auto_docstring def forward( self, @@ -1050,18 +906,13 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -1093,56 +944,28 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel): # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None 
else None - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( + return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, ) @@ -1343,6 +1166,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel): return image_features + @can_return_tuple @auto_docstring def forward( self, @@ -1392,7 +1216,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, + return_dict=True, ) text_outputs = self.text_model( @@ -1402,7 +1226,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel): position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) image_embeds = vision_outputs[1] @@ -1424,14 +1248,6 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel): if return_loss: loss = chinese_clip_loss(logits_per_text) - if not return_dict: - # fix the None pooled_output of text_outputs to conform with dict_output - pooled_output = text_outputs[1] - if pooled_output is None: - text_outputs = (text_outputs[0],) + text_outputs[2:] - output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) - return ((loss,) + output) if loss is not None else output - return ChineseCLIPOutput( loss=loss, logits_per_image=logits_per_image, diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 737dc6abad7..707c04d0586 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -17,7 +17,7 @@ import collections import math from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -26,13 +26,14 @@ from torch import nn from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer -from ...utils import ModelOutput, auto_docstring, logging, torch_int +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int +from ...utils.deprecation import deprecate_kwarg from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig @@ -399,11 +400,6 @@ class ClapAudioSelfAttention(nn.Module): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -412,11 +408,11 @@ class ClapAudioSelfAttention(nn.Module): output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, dim, num_channels = 
hidden_states.shape - mixed_query_layer = self.query(hidden_states) + hidden_shape = (batch_size, dim, -1, self.attention_head_size) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -1090,9 +1086,37 @@ class ClapTextEmbeddings(nn.Module): return position_ids.unsqueeze(0).expand(input_shape) -# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ClapText +# Copied from transformers.models.align.modeling_align.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->Clap class ClapTextSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -1100,6 +1124,7 @@ class ClapTextSelfAttention(nn.Module): f"heads ({config.num_attention_heads})" ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -1109,20 +1134,12 @@ class ClapTextSelfAttention(nn.Module): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - - self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) + self.attention_dropout = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + 
@deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -1132,96 +1149,33 @@ class ClapTextSelfAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_layer = self.transpose_for_scores(mixed_query_layer) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. 
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in ClapTextModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs @@ -1240,18 +1194,11 @@ class ClapTextSelfOutput(nn.Module): return hidden_states -CLAP_TEXT_SELF_ATTENTION_CLASSES = { - "eager": ClapTextSelfAttention, -} - - -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT +# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->Clap class ClapTextAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type - ) + self.self = ClapTextSelfAttention(config) self.output = ClapTextSelfOutput(config) self.pruned_heads = set() @@ -1273,6 +1220,9 @@ class ClapTextAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -1282,15 +1232,14 @@ class ClapTextAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -1328,22 +1277,19 @@ class ClapTextOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText +# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->Clap class ClapTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = ClapTextAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = ClapTextAttention(config, position_embedding_type="absolute") self.intermediate = ClapTextIntermediate(config) self.output = ClapTextOutput(config) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") 
+ @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -1353,60 +1299,23 @@ class ClapTextLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + **kwargs, ) attention_output = self_attention_outputs[0] - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -1415,14 +1324,19 @@ class ClapTextLayer(GradientCheckpointingLayer): return layer_output -# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ClapText +# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->Clap class ClapTextEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([ClapTextLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([ClapTextLayer(config) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple def forward( self, 
hidden_states: torch.Tensor, @@ -1435,65 +1349,36 @@ class ClapTextEncoder(nn.Module): output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, - ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, # as a positional argument for gradient checkpointing - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, output_attentions=output_attentions, + **kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutput( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, - cross_attentions=all_cross_attentions, ) @@ -1643,6 +1528,11 @@ class ClapTextModel(ClapPreTrainedModel): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple @auto_docstring def forward( self, @@ -1666,11 +1556,6 @@ class ClapTextModel(ClapPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -1684,11 +1569,8 @@ class ClapTextModel(ClapPreTrainedModel): batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = 
torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + attention_mask = torch.ones(((batch_size, seq_length)), device=device) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): @@ -1702,21 +1584,6 @@ class ClapTextModel(ClapPreTrainedModel): # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) @@ -1725,33 +1592,23 @@ class ClapTextModel(ClapPreTrainedModel): position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( + return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, ) @@ -1892,6 +1749,7 @@ class ClapModel(ClapPreTrainedModel): return audio_features + @can_return_tuple @auto_docstring def forward( self, @@ -1947,7 +1805,7 @@ class ClapModel(ClapPreTrainedModel): is_longer=is_longer, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) text_outputs = self.text_model( @@ -1956,7 +1814,7 @@ class ClapModel(ClapPreTrainedModel): position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) audio_embeds = audio_outputs[1] if not return_dict else audio_outputs.pooler_output @@ -1981,10 +1839,6 @@ class ClapModel(ClapPreTrainedModel): audio_loss = contrastive_loss(logits_per_audio.t()) loss = (caption_loss + audio_loss) / 2.0 - if not return_dict: - output = (logits_per_audio, logits_per_text, text_embeds, 
audio_embeds, text_outputs, audio_outputs) - return ((loss,) + output) if loss is not None else output - return ClapOutput( loss=loss, logits_per_audio=logits_per_audio, @@ -2013,6 +1867,7 @@ class ClapTextModelWithProjection(ClapPreTrainedModel): def set_input_embeddings(self, value): self.text_model.embeddings.word_embeddings = value + @can_return_tuple @auto_docstring def forward( self, @@ -2045,17 +1900,13 @@ class ClapTextModelWithProjection(ClapPreTrainedModel): position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output text_embeds = self.text_projection(pooled_output) - if not return_dict: - outputs = (text_embeds, text_outputs[0]) + text_outputs[2:] - return tuple(output for output in outputs if output is not None) - return ClapTextModelOutput( text_embeds=text_embeds, last_hidden_state=text_outputs.last_hidden_state, @@ -2079,6 +1930,7 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel): def get_input_embeddings(self) -> nn.Module: return self.audio_model.audio_encoder.patch_embed.proj + @can_return_tuple @auto_docstring def forward( self, @@ -2123,17 +1975,13 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel): is_longer=is_longer, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output audio_embeds = self.audio_projection(pooled_output) - if not return_dict: - outputs = (audio_embeds, audio_outputs[0]) + audio_outputs[2:] - return tuple(output for output in outputs if output is not None) - return ClapAudioModelOutput( audio_embeds=audio_embeds, last_hidden_state=audio_outputs.last_hidden_state, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 732712c517c..b6a12e6e636 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -28,7 +28,7 @@ from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepa from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...utils import ModelOutput, auto_docstring, logging, torch_int +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig @@ -490,6 +490,7 @@ class CLIPSegEncoder(nn.Module): self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @can_return_tuple def forward( self, inputs_embeds, @@ -555,8 +556,6 @@ class CLIPSegEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 19850c0dd8b..3b96c272c0d 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ 
-311,7 +311,7 @@ class CsmAttention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index e4ddac37541..60509f419fb 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -45,6 +45,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available +from ...utils.deprecation import deprecate_kwarg from .configuration_data2vec_audio import Data2VecAudioConfig @@ -240,6 +241,7 @@ class Data2VecAudioAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -247,7 +249,7 @@ class Data2VecAudioAttention(nn.Module): past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -268,42 +270,9 @@ class Data2VecAudioAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -325,7 +294,7 @@ class Data2VecAudioAttention(nn.Module): attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class Data2VecAudioFeedForward(nn.Module): diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 97bca6d0d69..f447ff6258d 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -634,7 +634,6 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel): self.encoder.layer[layer].attention.prune_heads(heads) @auto_docstring - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 19cac3e8c3a..b87712852fd 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -281,7 +281,7 @@ class DiaSelfAttention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index a2dafed7405..7af6a3ad07d 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -405,11 +405,6 @@ class DonutSwinSelfAttention(nn.Module): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -418,11 +413,11 @@ class DonutSwinSelfAttention(nn.Module): output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape - mixed_query_layer = self.query(hidden_states) + hidden_shape = (batch_size, dim, -1, self.attention_head_size) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = 
self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 3487138234b..422838603c0 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -189,7 +189,7 @@ class Emu3Attention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -1537,20 +1537,26 @@ class Emu3Model(Emu3PreTrainedModel): "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_tokens = self.get_image_tokens(pixel_values, image_sizes) - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id - image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) - input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) + image_embeds = self.get_image_features(pixel_values, image_sizes) + image_embeds = torch.cat(image_embeds, dim=0) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.text_model( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index fbcafff16ce..bfb22a96907 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -1033,20 +1033,26 @@ class Emu3Model(Emu3PreTrainedModel): "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_tokens = self.get_image_tokens(pixel_values, image_sizes) - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id - image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) 
- input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) + image_embeds = self.get_image_features(pixel_values, image_sizes) + image_embeds = torch.cat(image_embeds, dim=0) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.text_model( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 953a024a823..2f39fef5388 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -26,14 +26,15 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_esm import EsmConfig @@ -187,12 +188,16 @@ class EsmEmbeddings(nn.Module): self.mask_token_id = config.mask_token_id def forward( - self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + self, + input_ids=None, + attention_mask=None, + position_ids=None, + inputs_embeds=None, ): if position_ids is None: if input_ids is not None: # Create the position ids from the input token ids. Any padded tokens remain padded. 
- position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx) else: position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) @@ -281,11 +286,7 @@ class EsmSelfAttention(nn.Module): self.is_decoder = config.is_decoder - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -296,32 +297,22 @@ class EsmSelfAttention(nn.Module): past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size) + + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] + if is_cross_attention: + key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2) attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, @@ -329,16 +320,6 @@ class EsmSelfAttention(nn.Module): # ESM code and fix rotary embeddings. query_layer = query_layer * self.attention_head_size**-0.5 - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - if self.position_embedding_type == "rotary": query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) @@ -385,7 +366,7 @@ class EsmSelfAttention(nn.Module): outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: - outputs = outputs + (past_key_value,) + outputs = outputs + (None,) return outputs @@ -418,6 +399,7 @@ class EsmFlashAttention2(EsmSelfAttention): self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() self.dropout_prob = config.attention_probs_dropout_prob + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -441,7 +423,6 @@ class EsmFlashAttention2(EsmSelfAttention): head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, output_attentions, ) @@ -450,9 +431,6 @@ class EsmFlashAttention2(EsmSelfAttention): query_layer = self.transpose_for_scores(self.query(hidden_states)) key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - if past_key_value is not None: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need @@ -514,7 +492,7 @@ class EsmFlashAttention2(EsmSelfAttention): outputs = (attn_output, None) if self.is_decoder: - outputs = outputs + (past_key_value,) + outputs = outputs + (None,) return outputs @@ -551,6 +529,7 @@ class EsmAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states, @@ -564,12 +543,11 @@ class EsmAttention(nn.Module): hidden_states_ln = self.LayerNorm(hidden_states) self_outputs = self.self( hidden_states_ln, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -616,6 +594,7 @@ class EsmLayer(GradientCheckpointingLayer): self.output = EsmOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states, @@ -626,25 +605,20 @@ class EsmLayer(GradientCheckpointingLayer): past_key_value=None, output_attentions=False, ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, 
) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise AttributeError( @@ -652,31 +626,24 @@ class EsmLayer(GradientCheckpointingLayer): " with cross-attention layers by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - layer_output = self.feed_forward_chunk(attention_output) outputs = (layer_output,) + outputs # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (None,) return outputs def feed_forward_chunk(self, attention_output): @@ -694,6 +661,9 @@ class EsmEncoder(nn.Module): self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False + @deprecate_kwarg("past_key_value", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple def forward( self, hidden_states, @@ -707,38 +677,26 @@ class EsmEncoder(nn.Module): output_hidden_states=False, return_dict=True, ): - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." 
- ) - use_cache = False all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = next_decoder_cache + (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -750,21 +708,8 @@ class EsmEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutputWithCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -863,6 +808,9 @@ class EsmModel(EsmPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple @auto_docstring def forward( self, @@ -903,11 +851,6 @@ class EsmModel(EsmPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -921,11 +864,8 @@ class EsmModel(EsmPreTrainedModel): batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + attention_mask = torch.ones(((batch_size, seq_length)), device=device) if self.config._attn_implementation == "flash_attention_2": extended_attention_mask = attention_mask @@ -958,7 +898,6 @@ class EsmModel(EsmPreTrainedModel): position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, ) encoder_outputs = self.encoder( embedding_output, @@ -966,22 +905,16 @@ class EsmModel(EsmPreTrainedModel): head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, 
encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, @@ -1025,6 +958,7 @@ class EsmForMaskedLM(EsmPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings + @can_return_tuple @auto_docstring def forward( self, @@ -1058,7 +992,7 @@ class EsmForMaskedLM(EsmPreTrainedModel): encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = outputs[0] prediction_scores = self.lm_head(sequence_output) @@ -1070,10 +1004,6 @@ class EsmForMaskedLM(EsmPreTrainedModel): labels = labels.to(prediction_scores.device) masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - return MaskedLMOutput( loss=masked_lm_loss, logits=prediction_scores, @@ -1125,6 +1055,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel): self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -1154,7 +1085,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = outputs[0] logits = self.classifier(sequence_output) @@ -1184,10 +1115,6 @@ class EsmForSequenceClassification(EsmPreTrainedModel): loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutput( loss=loss, logits=logits, @@ -1210,6 +1137,7 @@ class EsmForTokenClassification(EsmPreTrainedModel): self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -1237,7 +1165,7 @@ class EsmForTokenClassification(EsmPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = outputs[0] @@ -1252,10 +1180,6 @@ class EsmForTokenClassification(EsmPreTrainedModel): labels = labels.to(logits.device) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -1283,7 +1207,7 @@ class EsmClassificationHead(nn.Module): return x -def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): +def create_position_ids_from_input_ids(input_ids, padding_idx): """ Replace non-padding symbols with their position numbers. 
Position numbers begin at padding_idx+1. Padding symbols are ignored. This is modified from fairseq's `utils.make_positions`. @@ -1295,7 +1219,7 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l """ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. mask = input_ids.ne(padding_idx).int() - incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask return incremental_indices.long() + padding_idx diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index ed826b20d8f..56ba62133f4 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -206,14 +206,22 @@ class FuyuModel(FuyuPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - if image_patches is not None and past_key_values is None: - patch_embeddings = self.get_image_features(image_patches) - patch_embeddings = torch.cat(patch_embeddings, dim=0) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings) + if image_patches is not None: + patch_embeddings = self.get_image_features(image_patches) + patch_embeddings = torch.cat(patch_embeddings, dim=0) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings) outputs = self.language_model( inputs_embeds=inputs_embeds, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 04b438c5ab4..399c809d126 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -222,7 +222,7 @@ class GemmaAttention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 51a2ac085be..2d86e9f04c0 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -898,9 +898,11 @@ class Gemma3Model(Gemma3PreTrainedModel): special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) + special_image_mask = special_image_mask.all(-1) else: - special_image_mask = (input_ids 
== self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index f748461dc46..0b0960a6a98 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -800,9 +800,11 @@ class Gemma3Model(PaliGemmaModel): special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) + special_image_mask = special_image_mask.all(-1) else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index d29bacf91e2..3a4995610d4 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1135,9 +1135,17 @@ class Gemma3nTextAltUp(nn.Module): corrected += predictions # add the original input return corrected.contiguous().type_as(activated) + def forward(self, corrected: torch.Tensor) -> torch.Tensor: + """ + This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale` + (which is a nn.Parameter, not a Module) between devices when offloading. 
It is otherwise only used in + `scale_corrected_output` + """ + return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size].""" - return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) + return self.forward(corrected) class Gemma3nTextRotaryEmbedding(nn.Module): @@ -1290,7 +1298,7 @@ class Gemma3nTextAttention(nn.Module): self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False) first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers - self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 # Find the index of the last sliding or full layer before sharing starts (or None if no sharing) layer_type = config.layer_types[layer_idx] self.kv_shared_layer_index = ( @@ -1319,21 +1327,22 @@ class Gemma3nTextAttention(nn.Module): query_states = query_states.transpose(1, 2) if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: - # HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache. + # Device of past layer may be different from current one + indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device) + # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) if isinstance(past_key_value, HybridCache) and self.is_sliding: max_length = past_key_value.sliding_window - if cache_position.shape[0] > max_length: - # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache, - # slice into the entire cache. - indices = slice(0, max_length) - else: - # If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1 - indices = cache_position.clamp(min=0, max=max_length - 1) - else: - indices = cache_position + indices = ( + slice(0, max_length) + if cache_position.shape[0] > max_length + else cache_position.clamp(min=0, max=max_length - 1) + ) - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] + # Device of past layer may be different from current one + key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device) + value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to( + query_states.device + ) else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) @@ -1447,10 +1456,9 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer): attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) - first_prediction = corrected_predictions[self.config.altup_active_idx] - first_prediction_clone = first_prediction.clone() + first_prediction = corrected_predictions[self.config.altup_active_idx].clone() if self.config.altup_correct_scale: - first_prediction = self.altup.scale_corrected_output(first_prediction_clone) + first_prediction = self.altup.scale_corrected_output(first_prediction) # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...) 
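The pattern above — routing the only use of a bare `nn.Parameter` through `forward`, with `scale_corrected_output` delegating to it — can be reproduced in isolation. A minimal sketch with invented sizes; `AltUpScale` is illustrative and not the library class, the point being that hooks which wrap a module's `forward` (as accelerate's offloading hooks do) get a place to move the parameter:

import torch
from torch import nn

class AltUpScale(nn.Module):
    """Illustrative stand-in: the parameter's only use lives in `forward`."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.correct_output_scale = nn.Parameter(torch.ones(hidden_size))

    def forward(self, corrected: torch.Tensor) -> torch.Tensor:
        # Same dtype round-trip as the patch: compute in the parameter's dtype,
        # return in the activation's dtype.
        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)

    def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
        # Delegates to `forward`; tooling that replaces `forward` in place
        # (e.g. for CPU offload) therefore covers this call path as well.
        return self.forward(corrected)

x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
print(AltUpScale(8).scale_corrected_output(x).dtype)  # torch.bfloat16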
first_prediction = self.per_layer_input_gate(first_prediction) @@ -1475,7 +1483,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel): config_class = Gemma3nConfig base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["Gemma3nDecoderLayer"] + _no_split_modules = ["Gemma3nTextDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_3 = True _supports_flash_attn_2 = True @@ -1656,18 +1664,17 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids) # Expand hidden_states to support per-layer inputs - target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 - epsilon_tensor = torch.tensor(torch.finfo().min) + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 + epsilon_tensor = torch.tensor(1e-5) temp_hidden_states = [hidden_states_0] for i in range(1, self.config.altup_num_inputs): # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...) - altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0) - current_hidden_state = altup_proj.type(hidden_states_0.dtype) - new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 - current_hidden_state = current_hidden_state * ( - target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) - ) + altup_proj = self.altup_projections[i - 1](hidden_states_0) + current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) + new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device))) + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude temp_hidden_states.append(current_hidden_state) hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size] @@ -1685,9 +1692,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): layer_outputs = decoder_layer( hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, - per_layer_input=per_layer_input, + position_embeddings_global, + position_embeddings_local, + per_layer_input, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -1712,11 +1719,10 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): for i in range(1, self.config.altup_num_inputs): # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...) 
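The magnitude-matching step above rescales each projected stream to the RMS of `hidden_states_0`, flooring the measured mean square at a small positive epsilon (replacing the old `torch.finfo().min` sentinel, which never acted as a floor). A minimal sketch with invented shapes; `match_rms` is not a library function:

import torch

def match_rms(x: torch.Tensor, target_rms: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Mean of squares first, floor it at `eps`, then take the square root --
    # mirroring the order of operations used in the patch.
    mean_sq = torch.mean(x**2, dim=-1, keepdim=True)
    rms = torch.sqrt(torch.maximum(mean_sq, torch.tensor(eps, dtype=mean_sq.dtype, device=x.device)))
    return x * target_rms / rms

hidden_states_0 = torch.randn(2, 3, 16)
target_rms = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
projected = 10.0 * torch.randn(2, 3, 16)
rescaled = match_rms(projected, target_rms)
print(torch.allclose(torch.mean(rescaled**2, dim=-1, keepdim=True) ** 0.5, target_rms, atol=1e-5))  # True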
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i]) - current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype) - new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 - current_hidden_state = current_hidden_state * ( - target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) - ) + current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) + new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device))) + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude temp_hidden_states.append(current_hidden_state) hidden_states = torch.stack(temp_hidden_states) @@ -1743,7 +1749,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): per_layer_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds) - per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype) + per_layer_projection *= self.per_layer_projection_scale.to( + dtype=inputs_embeds.dtype, device=per_layer_projection.device + ) per_layer_projection = per_layer_projection.reshape( *inputs_embeds.shape[:-1], self.config.num_hidden_layers, @@ -1758,7 +1766,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings. per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] - return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype) + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to( + dtype=inputs_embeds.dtype, device=per_layer_projection.device + ) @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.") diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index a18ac8c2ef2..b0a5099ff56 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1685,9 +1685,17 @@ class Gemma3nTextAltUp(nn.Module): corrected += predictions # add the original input return corrected.contiguous().type_as(activated) + def forward(self, corrected: torch.Tensor) -> torch.Tensor: + """ + This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale` + (which is a nn.Parameter, not a Module) between devices when offloading. 
It is otherwise only used in + `scale_corrected_output` + """ + return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size].""" - return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) + return self.forward(corrected) class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding): @@ -1732,7 +1740,7 @@ class Gemma3nTextAttention(Gemma3Attention): self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False) first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers - self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 # Find the index of the last sliding or full layer before sharing starts (or None if no sharing) layer_type = config.layer_types[layer_idx] self.kv_shared_layer_index = ( @@ -1761,21 +1769,22 @@ class Gemma3nTextAttention(Gemma3Attention): query_states = query_states.transpose(1, 2) if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: - # HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache. + # Device of past layer may be different from current one + indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device) + # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) if isinstance(past_key_value, HybridCache) and self.is_sliding: max_length = past_key_value.sliding_window - if cache_position.shape[0] > max_length: - # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache, - # slice into the entire cache. 
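The replacement slicing above reduces to a small rule: when the prefill is longer than the sliding window, the fixed-size cache only holds `sliding_window` slots, so the whole cache is read; otherwise positions are clamped into the valid range. A standalone sketch with made-up sizes (not the `HybridCache` API itself):

import torch

def shared_kv_indices(cache_position: torch.Tensor, sliding_window: int):
    # Prefill longer than the window: slice every slot the cache has.
    # Otherwise clamp the requested positions to valid slots.
    if cache_position.shape[0] > sliding_window:
        return slice(0, sliding_window)
    return cache_position.clamp(min=0, max=sliding_window - 1)

print(shared_kv_indices(torch.arange(10), sliding_window=4))   # slice(0, 4, None)
print(shared_kv_indices(torch.tensor([7]), sliding_window=4))  # tensor([3])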
- indices = slice(0, max_length) - else: - # If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1 - indices = cache_position.clamp(min=0, max=max_length - 1) - else: - indices = cache_position + indices = ( + slice(0, max_length) + if cache_position.shape[0] > max_length + else cache_position.clamp(min=0, max=max_length - 1) + ) - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] + # Device of past layer may be different from current one + key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device) + value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to( + query_states.device + ) else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) @@ -1880,10 +1889,9 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer): attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) - first_prediction = corrected_predictions[self.config.altup_active_idx] - first_prediction_clone = first_prediction.clone() + first_prediction = corrected_predictions[self.config.altup_active_idx].clone() if self.config.altup_correct_scale: - first_prediction = self.altup.scale_corrected_output(first_prediction_clone) + first_prediction = self.altup.scale_corrected_output(first_prediction) # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...) first_prediction = self.per_layer_input_gate(first_prediction) @@ -1906,7 +1914,7 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer): class Gemma3nPreTrainedModel(Gemma2PreTrainedModel): config_class = Gemma3nConfig base_model_prefix = "" - _no_split_modules = ["Gemma3nDecoderLayer"] + _no_split_modules = ["Gemma3nTextDecoderLayer"] def _init_weights(self, module): # important: this ported version of Gemma2 isn't meant for training from scratch - only @@ -1995,7 +2003,9 @@ class Gemma3nTextModel(Gemma3TextModel): per_layer_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds) - per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype) + per_layer_projection *= self.per_layer_projection_scale.to( + dtype=inputs_embeds.dtype, device=per_layer_projection.device + ) per_layer_projection = per_layer_projection.reshape( *inputs_embeds.shape[:-1], self.config.num_hidden_layers, @@ -2010,7 +2020,9 @@ class Gemma3nTextModel(Gemma3TextModel): # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings. 
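The per-layer merge above is a trim, add, and rescale; moving the scalar scale with `.to(dtype=..., device=...)` keeps it on the same device and dtype as the activations it multiplies when layers are spread across devices. A toy sketch with invented sizes:

import torch

num_layers, per_layer_dim = 4, 8
inputs_embeds = torch.randn(2, 5, 32, dtype=torch.bfloat16)
per_layer_projection = torch.randn(2, 5, num_layers, per_layer_dim, dtype=torch.bfloat16)
# Per-layer inputs may arrive zero-padded beyond `num_layers`; keep only the first `num_layers`.
per_layer_inputs = torch.randn(2, 5, num_layers + 2, per_layer_dim, dtype=torch.bfloat16)
per_layer_inputs = per_layer_inputs[..., :num_layers, :]

per_layer_input_scale = torch.rsqrt(torch.tensor(2.0))  # float32 scalar stand-in for the model's scale buffer
scale = per_layer_input_scale.to(dtype=inputs_embeds.dtype, device=per_layer_projection.device)
merged = (per_layer_projection + per_layer_inputs) * scale
print(merged.shape, merged.dtype)  # torch.Size([2, 5, 4, 8]) torch.bfloat16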
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] - return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype) + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to( + dtype=inputs_embeds.dtype, device=per_layer_projection.device + ) @can_return_tuple @auto_docstring @@ -2091,18 +2103,17 @@ class Gemma3nTextModel(Gemma3TextModel): position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids) # Expand hidden_states to support per-layer inputs - target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 - epsilon_tensor = torch.tensor(torch.finfo().min) + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 + epsilon_tensor = torch.tensor(1e-5) temp_hidden_states = [hidden_states_0] for i in range(1, self.config.altup_num_inputs): # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...) - altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0) - current_hidden_state = altup_proj.type(hidden_states_0.dtype) - new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 - current_hidden_state = current_hidden_state * ( - target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) - ) + altup_proj = self.altup_projections[i - 1](hidden_states_0) + current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) + new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device))) + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude temp_hidden_states.append(current_hidden_state) hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size] @@ -2120,9 +2131,9 @@ class Gemma3nTextModel(Gemma3TextModel): layer_outputs = decoder_layer( hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, - per_layer_input=per_layer_input, + position_embeddings_global, + position_embeddings_local, + per_layer_input, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -2147,11 +2158,10 @@ class Gemma3nTextModel(Gemma3TextModel): for i in range(1, self.config.altup_num_inputs): # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...) 
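The `jax.numpy.einsum("btp,pd->btd", ...)` referenced in these comments is simply a bias-free linear layer over the last axis, which is how the port expresses it. A quick equivalence check with invented shapes:

import torch
from torch import nn

b, t, p, d = 2, 5, 4, 16
x = torch.randn(b, t, p)
proj = nn.Linear(p, d, bias=False)

via_linear = proj(x)                                        # (b, t, d)
via_einsum = torch.einsum("btp,pd->btd", x, proj.weight.T)  # same contraction, written out
print(torch.allclose(via_linear, via_einsum, atol=1e-6))    # True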
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i]) - current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype) - new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 - current_hidden_state = current_hidden_state * ( - target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) - ) + current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) + new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device))) + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude temp_hidden_states.append(current_hidden_state) hidden_states = torch.stack(temp_hidden_states) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 805192cf5a1..a501d03a7c1 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -39,6 +39,7 @@ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and from ...utils import ( ModelOutput, auto_docstring, + can_return_tuple, logging, torch_int, ) @@ -770,6 +771,7 @@ class GitVisionEncoder(nn.Module): self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @can_return_tuple def forward( self, inputs_embeds, @@ -836,8 +838,6 @@ class GitVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 86538fc25e5..17deed6bc70 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -184,7 +184,7 @@ class GlmAttention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 55cc8869d95..ddb15923886 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -242,7 +242,7 @@ class Glm4Attention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 65ec7f0b79c..4148dfd10ac 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -279,14 +279,15 @@ def eager_attention_forward( class Glm4vVisionAttention(nn.Module): def __init__(self, config: 
Glm4vVisionConfig) -> None: super().__init__() - self.config = config + self.dim = config.hidden_size self.num_heads = config.num_heads - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_groups = 1 - self.scale = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = config.attention_dropout self.is_causal = False def forward( @@ -295,23 +296,31 @@ class Glm4vVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs: Unpack[FlashAttentionKwargs], + attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] query_states, key_states, value_states = ( self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) ) - - cos, sin = position_embeddings + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - - attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -322,13 +331,17 @@ class Glm4vVisionAttention(nn.Module): query_states, key_states, value_states, - attention_mask, + attention_mask=attention_mask, dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scale, - is_causal=self.is_causal, + scaling=self.scaling, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, **kwargs, ) - attn_output = attn_output.squeeze(0) + attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.proj(attn_output) return attn_output @@ -348,6 +361,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -355,6 +369,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, 
rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, + attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -452,6 +467,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb, pos_ids + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: """ Args: @@ -481,14 +515,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens) for blk in self.blocks: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens, None, position_embeddings - ) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) hidden_states = self.post_layernorm(hidden_states) @@ -1202,50 +1237,59 @@ class Glm4vModel(Glm4vPreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = torch.cat(image_embeds, dim=0) - n_image_tokens = (input_ids == self.config.image_token_id).sum() + + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + image_mask = image_mask.all(-1) + else: + image_mask = input_ids == self.config.image_token_id + + n_image_tokens = image_mask.sum() + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) n_image_features = image_embeds.shape[0] if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - mask = input_ids == 
self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = torch.cat(video_embeds, dim=0) - n_video_tokens = (input_ids == self.config.image_token_id).sum() + + if input_ids is None: + video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + video_mask = video_mask.all(-1) + else: + video_mask = input_ids == self.config.video_token_id + + n_video_tokens = (video_mask).sum() n_video_features = video_embeds.shape[0] + video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - mask = input_ids == self.config.image_token_id # GLM-4.1V use image_token_id for video - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - video_mask = mask_expanded.to(inputs_embeds.device) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if position_ids is None: - attention_mask_tensor = attention_mask + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min @@ -1536,6 +1580,7 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin): def _get_image_nums_and_video_nums( self, input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Get the number of images and videos for each sample to calculate the separation length of the sample tensor. 
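The `_prepare_attention_mask` helper introduced above turns the `cu_seqlens` block boundaries into an additive block-diagonal mask (zeros inside each block, the dtype's most negative value everywhere else); the flash-attention 2 path skips the mask and consumes `cu_seqlens`/`max_seqlen` directly. A self-contained sketch with a made-up packing of two blocks:

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # cu_seqlens = cumulative sequence lengths, e.g. [0, 3, 7] for blocks of 3 and 4 tokens.
    seq_length = int(cu_seqlens[-1])
    mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
    return mask

cu_seqlens = torch.tensor([0, 3, 7])
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 4, what the varlen kernels receive
print(block_diagonal_mask(cu_seqlens)[0, 0])
print(max_seqlen)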
@@ -1550,9 +1595,29 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin): video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) """ - is_image = input_ids == self.config.image_start_token_id - is_video_start = input_ids == self.config.video_start_token_id - is_video_end = input_ids == self.config.video_end_token_id + if inputs_embeds is not None: + is_image = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + is_video_start = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + is_video_end = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + is_image = input_ids == self.config.image_start_token_id + is_video_start = input_ids == self.config.video_start_token_id + is_video_end = input_ids == self.config.video_end_token_id # Cumulative sum to track if we're inside a video span # We'll assume well-formed video tags (i.e. matching starts and ends) @@ -1588,7 +1653,9 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin): def _expand_dict_for_generation_visual(dict_to_expand): image_grid_thw = model_kwargs.get("image_grid_thw", None) video_grid_thw = model_kwargs.get("video_grid_thw", None) - image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) def _repeat_interleave_samples(x, lengths, repeat_times): samples = torch.split(x, lengths) @@ -1644,10 +1711,7 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin): dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) return dict_to_expand - # input_ids is required for expanding visual inputs - # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. 
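The embedding-comparison trick used throughout this patch recovers token positions when only `inputs_embeds` is available: a position matches the special token's embedding row in every dimension iff that position held the special token (assuming no other embedding row happens to be identical). A minimal sketch with a toy vocabulary, also showing the `masked_scatter` step the models above use to drop image features into those slots:

import torch
from torch import nn

torch.manual_seed(0)
embed = nn.Embedding(10, 6)
image_token_id = 7
input_ids = torch.tensor([[1, 7, 7, 3]])
inputs_embeds = embed(input_ids)

# Recover the positions of `image_token_id` without `input_ids`.
image_row = embed(torch.tensor(image_token_id))
special_image_mask = (inputs_embeds == image_row).all(-1)
print(special_image_mask)  # tensor([[False,  True,  True, False]])

# Scatter pre-computed image features into those slots.
image_features = torch.randn(2, 6)  # one row per image token
expanded = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(expanded, image_features.to(inputs_embeds.dtype))
print(torch.equal(merged[0, 1], image_features[0]), torch.equal(merged[0, 2], image_features[1]))  # True True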
- if input_ids is not None and input_ids.numel() != 0: - model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) if input_ids is not None: input_ids = input_ids.repeat_interleave(expand_size, dim=0) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index c6ca61b2153..cf4a6b9233f 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -50,8 +50,8 @@ from ..qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLPreTrainedModel, Qwen2_5_VLRotaryEmbedding, Qwen2_5_VLTextModel, + Qwen2_5_VLVisionAttention, Qwen2_5_VLVisionBlock, - apply_rotary_pos_emb_vision, ) from ..qwen2_5_vl.processing_qwen2_5_vl import ( Qwen2_5_VLProcessor, @@ -505,62 +505,12 @@ class Glm4vVisionEmbeddings(nn.Module): return embeddings -class Glm4vVisionAttention(nn.Module): +class Glm4vVisionAttention(Qwen2_5_VLVisionAttention): def __init__(self, config: Glm4vVisionConfig) -> None: super().__init__() - self.config = config - self.num_heads = config.num_heads - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_groups = 1 - self.scale = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) - self.is_causal = False - - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - query_states, key_states, value_states = ( - self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - ) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) - - query_states = query_states.transpose(0, 1).unsqueeze(0) - key_states = key_states.transpose(0, 1).unsqueeze(0) - value_states = value_states.transpose(0, 1).unsqueeze(0) - - attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scale, - is_causal=self.is_causal, - **kwargs, - ) - attn_output = attn_output.squeeze(0) - attn_output = attn_output.reshape(seq_length, -1).contiguous() - attn_output = self.proj(attn_output) - return attn_output class Glm4vVisionBlock(Qwen2_5_VLVisionBlock): @@ -653,6 +603,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb, pos_ids + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the 
ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: """ Args: @@ -682,14 +651,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens) for blk in self.blocks: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens, None, position_embeddings - ) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) hidden_states = self.post_layernorm(hidden_states) @@ -1267,50 +1237,59 @@ class Glm4vModel(Qwen2_5_VLModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = torch.cat(image_embeds, dim=0) - n_image_tokens = (input_ids == self.config.image_token_id).sum() + + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + image_mask = image_mask.all(-1) + else: + image_mask = input_ids == self.config.image_token_id + + n_image_tokens = image_mask.sum() + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) n_image_features = image_embeds.shape[0] if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = torch.cat(video_embeds, dim=0) - n_video_tokens = (input_ids == self.config.image_token_id).sum() + + if input_ids is None: + video_mask = inputs_embeds == 
self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + video_mask = video_mask.all(-1) + else: + video_mask = input_ids == self.config.video_token_id + + n_video_tokens = (video_mask).sum() n_video_features = video_embeds.shape[0] + video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - mask = input_ids == self.config.image_token_id # GLM-4.1V use image_token_id for video - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - video_mask = mask_expanded.to(inputs_embeds.device) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if position_ids is None: - attention_mask_tensor = attention_mask + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min @@ -1530,6 +1509,7 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): def _get_image_nums_and_video_nums( self, input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Get the number of images and videos for each sample to calculate the separation length of the sample tensor. @@ -1544,9 +1524,29 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) """ - is_image = input_ids == self.config.image_start_token_id - is_video_start = input_ids == self.config.video_start_token_id - is_video_end = input_ids == self.config.video_end_token_id + if inputs_embeds is not None: + is_image = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + is_video_start = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + is_video_end = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + is_image = input_ids == self.config.image_start_token_id + is_video_start = input_ids == self.config.video_start_token_id + is_video_end = input_ids == self.config.video_end_token_id # Cumulative sum to track if we're inside a video span # We'll assume well-formed video tags (i.e. 
matching starts and ends) diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 1b5cd63e7fd..2163dfcea75 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -648,24 +648,27 @@ class GotOcr2Model(GotOcr2PreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index 98127283a62..e2b001f2425 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -339,24 +339,27 @@ class GotOcr2Model(LlavaModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: 
tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index b65530c4061..d5e7e4fe517 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -148,7 +148,7 @@ class GraniteAttention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 3a48d931ca1..fb8fc6f3be2 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -224,7 +224,7 @@ class HeliumAttention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 0fab4184bfe..810279c7acf 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_hubert import HubertConfig @@ -300,6 +301,7 @@ class HubertAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -307,7 +309,7 @@ class HubertAttention(nn.Module): past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -328,42 +330,9 @@ class HubertAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - 
is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -385,7 +354,7 @@ class HubertAttention(nn.Module): attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class HubertFeedForward(nn.Module): diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index c92bd7ba9c4..8682ff047a8 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -28,6 +28,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import ( ModelOutput, + can_return_tuple, logging, ) from .configuration_idefics import IdeficsVisionConfig @@ -351,6 +352,7 @@ class IdeficsVisionEncoder(nn.Module): self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @can_return_tuple def forward( self, inputs_embeds, @@ -417,8 +419,6 @@ class IdeficsVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 9757a42049f..e18e4ee1379 
100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -933,10 +933,18 @@ class Idefics2Model(Idefics2PreTrainedModel): - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. """ - special_image_token_mask = input_ids == self.image_token_id - new_inputs_embeds = inputs_embeds.clone() - new_inputs_embeds[special_image_token_mask] = image_hidden_states.to(new_inputs_embeds.device) - return new_inputs_embeds + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states) + return inputs_embeds def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None): """ @@ -1041,25 +1049,8 @@ class Idefics2Model(Idefics2PreTrainedModel): else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_seen_tokens = 0 - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache: - if not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) - past_seen_tokens = past_key_values.get_seq_length() - - if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: - raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.text_model.get_input_embeddings()(input_ids) @@ -1072,7 +1063,7 @@ class Idefics2Model(Idefics2PreTrainedModel): elif image_hidden_states is not None: image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) - if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None: + if image_hidden_states is not None: # When we generate, we don't want to replace the potential image_token_id that we generated by images # that simply don't exist inputs_embeds = self.inputs_merger( @@ -1094,9 +1085,6 @@ class Idefics2Model(Idefics2PreTrainedModel): **kwargs, ) - if return_legacy_cache and use_cache: - outputs.past_key_values = outputs.past_key_values.to_legacy_cache() - return Idefics2BaseModelOutputWithPast( last_hidden_state=outputs.last_hidden_state, past_key_values=outputs.past_key_values, @@ -1304,37 +1292,11 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin) **kwargs, ) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - # but IDEFICS requires both ids and embeds to be present - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs["input_ids"] = input_ids - - if image_hidden_states is not None: + if image_hidden_states is not None or cache_position[0] != 0: model_inputs["pixel_values"] = None model_inputs["pixel_attention_mask"] = None return model_inputs - def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): - model_kwargs = super()._update_model_kwargs_for_generation( - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - **kwargs, - ) - # Get the precomputed image_hidden_states - model_kwargs["image_hidden_states"] = outputs.image_hidden_states - return model_kwargs - - @staticmethod - # Copied from transformers.models.opt.modeling_opt.OPTForCausalLM._reorder_cache - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - __all__ = ["Idefics2ForConditionalGeneration", "Idefics2PreTrainedModel", "Idefics2Model"] diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index a2e0bc78d0f..a0494cb7411 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -663,15 +663,18 @@ class Idefics3Model(Idefics3PreTrainedModel): - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. 
""" - special_image_token_mask = input_ids == self.image_token_id - # Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. - new_inputs_embeds = inputs_embeds.clone() - # Flatten `image_hidden_states` if not flat yet - image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1]) - # cast to the dtype of the input_embeds to support quantized models + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype) - new_inputs_embeds[special_image_token_mask] = image_hidden_states - return new_inputs_embeds + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states) + return inputs_embeds def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None): """ @@ -773,11 +776,8 @@ class Idefics3Model(Idefics3PreTrainedModel): else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_seen_tokens = 0 - if use_cache: - if past_key_values is None: - past_key_values = DynamicCache() - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device) @@ -790,7 +790,7 @@ class Idefics3Model(Idefics3PreTrainedModel): elif image_hidden_states is not None: image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) - if past_seen_tokens == 0 and input_ids is not None and image_hidden_states is not None: + if image_hidden_states is not None: # When we generate, we don't want to replace the potential image_token_id that we generated by images # that simply don't exist inputs_embeds = self.inputs_merger( @@ -1042,28 +1042,11 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin) **kwargs, ) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - # but IDEFICS requires both ids and embeds to be present - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs["input_ids"] = input_ids - - if image_hidden_states is not None: + if image_hidden_states is not None or cache_position[0] != 0: model_inputs["pixel_values"] = None model_inputs["pixel_attention_mask"] = None return model_inputs - # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration._update_model_kwargs_for_generation - def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): - model_kwargs = super()._update_model_kwargs_for_generation( - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - **kwargs, - ) - # Get the precomputed image_hidden_states - model_kwargs["image_hidden_states"] = outputs.image_hidden_states - return model_kwargs - __all__ = ["Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", "Idefics3VisionTransformer"] diff --git a/src/transformers/models/instructblip/modeling_instructblip.py 
b/src/transformers/models/instructblip/modeling_instructblip.py index fa064930993..8c75db38d22 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -1255,6 +1255,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel): attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -1328,12 +1329,20 @@ class InstructBlipModel(InstructBlipPreTrainedModel): # step 3: use the language model, conditioned on the query outputs and the prompt language_model_inputs = self.language_projection(query_output) - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) + if inputs_embeds is None: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + special_image_mask = input_ids == self.config.image_token_id + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + else: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -1513,6 +1522,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, @@ -1604,15 +1614,26 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if attention_mask is None: attention_mask = torch.ones_like(input_ids) # if the model already has "image_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating if getattr(self.config, "image_token_id", None) is not None: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = 
input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) else: logger.warning_once( "Expanding inputs for image tokens in InstructBLIP should be done in processing. " @@ -1673,6 +1694,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati qformer_attention_mask: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: @@ -1690,6 +1712,8 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati The sequence used as a prompt for the generation. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): Mask to avoid performing attention on padding token indices. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): Whether to interpolate the positional encoding of the image embeddings. @@ -1712,23 +1736,32 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) - if input_ids is None: - start_tokens = [self.config.text_config.bos_token_id] - if getattr(self.config, "image_token_id", None) is not None: - start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens - input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) - input_ids = input_ids.repeat(batch_size, 1) + if inputs_embeds is None: + if input_ids is None: + start_tokens = [self.config.text_config.bos_token_id] + if getattr(self.config, "image_token_id", None) is not None: + start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device) + input_ids = input_ids.repeat(batch_size, 1) + inputs_embeds = self.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - inputs_embeds = self.get_input_embeddings()(input_ids) - # if the model already has "image_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating if getattr(self.config, "image_token_id", None) is not None: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = 
inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) else: logger.warning_once( "Expanding inputs for image tokens in InstructBLIP should be done in processing. " diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index ea3d19bd3f5..2989e08d091 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1251,6 +1251,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel): attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -1334,12 +1335,20 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel): # unbatch inputs back, each video-frame gets `num_query_tokens` seq length language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) + if inputs_embeds is None: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + special_image_mask = input_ids == self.config.video_token_id + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + else: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -1485,6 +1494,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, @@ -1599,15 +1609,26 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if attention_mask is None: attention_mask = torch.ones_like(input_ids) # if the model already has "video_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating if getattr(self.config, "video_token_id", None) is not None: - special_image_mask = (input_ids == 
self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.video_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) else: logger.warning_once( "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " @@ -1668,6 +1689,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel qformer_attention_mask: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: @@ -1685,6 +1707,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel The sequence used as a prompt for the generation. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): Mask to avoid performing attention on padding token indices. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): Whether to interpolate the positional encoding of the image embeddings. 
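[Editor's note] The recurring change across the model files in this patch (Idefics2/3, InstructBLIP, InstructBLIP-Video, and the models that follow) is the same placeholder-merge pattern: the positions of the special image/video tokens are located either from `input_ids` or, when only `inputs_embeds` is passed, by comparing each embedding against the embedded special token, and the vision features are then written into those slots with `masked_scatter` instead of in-place boolean-index assignment. A minimal standalone sketch of that pattern is given below; the helper name, argument names, and shapes are illustrative only, not the exact model code.

from typing import Optional

import torch
from torch import nn


def merge_special_token_features(
    inputs_embeds: torch.Tensor,   # (batch, seq_len, hidden)
    features: torch.Tensor,        # (num_placeholder_tokens, hidden), already flattened
    embed_tokens: nn.Embedding,    # the language model's input embedding layer
    special_token_id: int,
    input_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if input_ids is None:
        # No ids available: find placeholder positions by comparing every
        # embedding vector against the embedded special token.
        special_embed = embed_tokens(
            torch.tensor(special_token_id, dtype=torch.long, device=inputs_embeds.device)
        )
        special_mask = (inputs_embeds == special_embed).all(-1)   # (batch, seq_len)
    else:
        special_mask = input_ids == special_token_id               # (batch, seq_len)

    # Broadcast the mask over the hidden dimension and scatter the features into
    # the placeholder slots; masked_scatter is out-of-place, so no clone() is
    # needed to avoid in-place writes on a leaf tensor that requires grad.
    special_mask = special_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
    features = features.to(inputs_embeds.device, inputs_embeds.dtype)
    return inputs_embeds.masked_scatter(special_mask, features)

Because `masked_scatter` fills masked positions in row-major order, the number of True positions in the mask must match `features.numel()`; some hunks in this patch (for example the InternVL and Llama4 changes further down) keep an explicit count check between the placeholder tokens and the projected vision features for exactly that reason.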
@@ -1708,23 +1732,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) - if input_ids is None: - start_tokens = [self.config.text_config.bos_token_id] - if getattr(self.config, "video_token_id", None) is not None: - start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens - input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) - input_ids = input_ids.repeat(batch_size, 1) + if inputs_embeds is None: + if input_ids is None: + start_tokens = [self.config.text_config.bos_token_id] + if getattr(self.config, "video_token_id", None) is not None: + start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device) + input_ids = input_ids.repeat(batch_size, 1) + inputs_embeds = self.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - inputs_embeds = self.get_input_embeddings()(input_ids) - # if the model already has "video_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating if getattr(self.config, "video_token_id", None) is not None: - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.video_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) else: logger.warning_once( "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. 
" diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py index e2e6496ed6d..4ec768e4a87 100644 --- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -202,6 +202,7 @@ class InstructBlipVideoModel(InstructBlipModel): attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -255,12 +256,20 @@ class InstructBlipVideoModel(InstructBlipModel): # unbatch inputs back, each video-frame gets `num_query_tokens` seq length language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) + if inputs_embeds is None: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + special_image_mask = input_ids == self.config.video_token_id + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + else: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) if self.config.use_decoder_only_language_model: outputs = self.language_model( @@ -372,6 +381,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera attention_mask: Optional[torch.LongTensor] = None, decoder_input_ids: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, @@ -451,15 +461,26 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if attention_mask is None: attention_mask = torch.ones_like(input_ids) # if the model already has "video_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating if getattr(self.config, "video_token_id", None) is not None: - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + 
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.video_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) else: logger.warning_once( "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " @@ -520,6 +541,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera qformer_attention_mask: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, interpolate_pos_encoding: bool = False, **generate_kwargs, ) -> torch.LongTensor: @@ -537,6 +559,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera The sequence used as a prompt for the generation. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*): Mask to avoid performing attention on padding token indices. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): Whether to interpolate the positional encoding of the image embeddings. @@ -560,23 +584,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) - if input_ids is None: - start_tokens = [self.config.text_config.bos_token_id] - if getattr(self.config, "video_token_id", None) is not None: - start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens - input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) - input_ids = input_ids.repeat(batch_size, 1) + if inputs_embeds is None: + if input_ids is None: + start_tokens = [self.config.text_config.bos_token_id] + if getattr(self.config, "video_token_id", None) is not None: + start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device) + input_ids = input_ids.repeat(batch_size, 1) + inputs_embeds = self.get_input_embeddings()(input_ids) if attention_mask is None: attention_mask = torch.ones_like(input_ids) - inputs_embeds = self.get_input_embeddings()(input_ids) - # if the model already has "video_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating if getattr(self.config, "video_token_id", None) is not None: - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) - inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.video_token_id + + special_image_mask = 
special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) else: logger.warning_once( "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. " diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index e634a281c5c..26f26fae838 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -710,14 +710,14 @@ class InternVLModel(InternVLPreTrainedModel): special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + special_image_mask = special_image_mask.all(-1) else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_tokens = (input_ids == self.config.image_token_id).sum() + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 17fac990805..7ddc11eb7be 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -641,14 +641,14 @@ class InternVLModel(LlavaModel): special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + special_image_mask = special_image_mask.all(-1) else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_tokens = (input_ids == self.config.image_token_id).sum() + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index abdfea032f8..87db9b8a6ec 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1102,23 +1102,21 @@ class 
JanusModel(JanusPreTrainedModel): ) use_cache = False - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: + if input_ids is None: + image_attention_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + image_attention_mask = image_attention_mask.all(-1) + else: + image_attention_mask = input_ids == self.config.image_token_id + + image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = self.get_image_features(pixel_values) - image_attention_mask = input_ids == self.config.image_token_id - - embed_dim = inputs_embeds.shape[-1] - image_features = image_embeds.reshape(-1, embed_dim) - image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim) - - image_attention_mask = image_attention_mask.to(inputs_embeds.device) + image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1]) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index d485074df39..5fb18d83d72 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -955,23 +955,21 @@ class JanusModel(JanusPreTrainedModel): ) use_cache = False - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: + if input_ids is None: + image_attention_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + image_attention_mask = image_attention_mask.all(-1) + else: + image_attention_mask = input_ids == self.config.image_token_id + + image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = self.get_image_features(pixel_values) - image_attention_mask = input_ids == self.config.image_token_id - - embed_dim = inputs_embeds.shape[-1] - image_features = image_embeds.reshape(-1, embed_dim) - image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim) - - image_attention_mask = image_attention_mask.to(inputs_embeds.device) + image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1]) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 0926d17b318..0ca82e4b771 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -451,6 +451,7 @@ class Kosmos2VisionEncoder(nn.Module): self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @can_return_tuple def 
forward( self, inputs_embeds, @@ -517,8 +518,6 @@ class Kosmos2VisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) @@ -1468,25 +1467,19 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): image_embeds_position_mask=None, past_key_values=None, attention_mask=None, + inputs_embeds=None, use_cache=None, cache_position=None, **model_kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - # Kosmos2 has offset for position ids, so we need to create them correctly - position_ids = create_position_ids_from_input_ids( - input_ids, - padding_idx=self.config.pad_token_id, - past_key_values_length=0, - ) - if past_key_values is not None: image_embeds = None image_embeds_position_mask = None # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation) elif image_embeds_position_mask is not None: - batch_size, seq_len = input_ids.size() + batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size() mask_len = image_embeds_position_mask.size()[-1] image_embeds_position_mask = torch.cat( ( @@ -1502,11 +1495,13 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, + inputs_embeds=inputs_embeds, use_cache=use_cache, - position_ids=position_ids, cache_position=cache_position, **model_kwargs, ) + # Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer + model_inputs.pop("position_ids", None) return model_inputs @@ -1876,6 +1871,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ): # in order to allow `inputs` argument (as in `GenerationMixin`) @@ -1901,6 +1897,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): attention_mask=attention_mask, image_embeds=image_embeds, image_embeds_position_mask=image_embeds_position_mask, + inputs_embeds=inputs_embeds, **kwargs, ) diff --git a/src/transformers/models/layoutlm/configuration_layoutlm.py b/src/transformers/models/layoutlm/configuration_layoutlm.py index 1a440cf55e8..95bc2eda6fa 100644 --- a/src/transformers/models/layoutlm/configuration_layoutlm.py +++ b/src/transformers/models/layoutlm/configuration_layoutlm.py @@ -14,6 +14,7 @@ # limitations under the License. 
"""LayoutLM model configuration""" +import warnings from collections import OrderedDict from collections.abc import Mapping from typing import Any, Optional @@ -130,10 +131,22 @@ class LayoutLMConfig(PretrainedConfig): self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type + self._position_embedding_type = position_embedding_type self.use_cache = use_cache self.max_2d_position_embeddings = max_2d_position_embeddings + @property + def position_embedding_type(self): + warnings.warn( + "The `position_embedding_type` attribute is deprecated and will be removed in v4.55.", + FutureWarning, + ) + return self._position_embedding_type + + @position_embedding_type.setter + def position_embedding_type(self, value): + self._position_embedding_type = value + class LayoutLMOnnxConfig(OnnxConfig): def __init__( diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 87dfed1a8c3..6fd8fcc8078 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -14,8 +14,7 @@ # limitations under the License. """PyTorch LayoutLM model.""" -import math -from typing import Optional, Union +from typing import Callable, Optional, Union import torch import torch.utils.checkpoint @@ -25,16 +24,17 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, + BaseModelOutput, + BaseModelOutputWithPooling, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_layoutlm import LayoutLMConfig @@ -120,9 +120,37 @@ class LayoutLMEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM +# Copied from transformers.models.align.modeling_align.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with 
AlignText->LayoutLM class LayoutLMSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -130,6 +158,7 @@ class LayoutLMSelfAttention(nn.Module): f"heads ({config.num_attention_heads})" ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -139,20 +168,12 @@ class LayoutLMSelfAttention(nn.Module): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - - self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) + self.attention_dropout = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -162,96 +183,33 @@ class LayoutLMSelfAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. 
- is_cross_attention = encoder_hidden_states is not None + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_layer = self.transpose_for_scores(mixed_query_layer) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. 
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in LayoutLMModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs @@ -270,18 +228,11 @@ class LayoutLMSelfOutput(nn.Module): return hidden_states -LAYOUTLM_SELF_ATTENTION_CLASSES = { - "eager": LayoutLMSelfAttention, -} - - -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM +# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->LayoutLM class LayoutLMAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type - ) + self.self = LayoutLMSelfAttention(config) self.output = LayoutLMSelfOutput(config) self.pruned_heads = set() @@ -303,6 +254,9 @@ class LayoutLMAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -312,15 +266,14 @@ class LayoutLMAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -358,22 +311,19 @@ class LayoutLMOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM +# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->LayoutLM class LayoutLMLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = LayoutLMAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute") self.intermediate = LayoutLMIntermediate(config) self.output = LayoutLMOutput(config) + @deprecate_kwarg("encoder_hidden_states", 
version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -383,60 +333,23 @@ class LayoutLMLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + **kwargs, ) attention_output = self_attention_outputs[0] - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -445,14 +358,19 @@ class LayoutLMLayer(GradientCheckpointingLayer): return layer_output -# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->LayoutLM +# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->LayoutLM class LayoutLMEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([LayoutLMLayer(config) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + 
@can_return_tuple def forward( self, hidden_states: torch.Tensor, @@ -465,65 +383,36 @@ class LayoutLMEncoder(nn.Module): output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, - ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, # as a positional argument for gradient checkpointing - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, output_attentions=output_attentions, + **kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutput( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, - cross_attentions=all_cross_attentions, ) @@ -648,6 +537,9 @@ class LayoutLMModel(LayoutLMPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @can_return_tuple @auto_docstring def forward( self, @@ -663,7 +555,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]: + ) -> Union[tuple, BaseModelOutputWithPooling]: r""" bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*): Bounding boxes of each input sequence tokens. 
Selected in the range `[0, @@ -756,20 +648,16 @@ class LayoutLMModel(LayoutLMPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( + return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, ) @@ -796,6 +684,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel): self.cls.predictions.decoder = new_embeddings self.cls.predictions.bias = new_embeddings.bias + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @can_return_tuple @auto_docstring def forward( self, @@ -871,11 +762,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel): position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = outputs[0] @@ -889,10 +778,6 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel): labels.view(-1), ) - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - return MaskedLMOutput( loss=masked_lm_loss, logits=prediction_scores, @@ -921,6 +806,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel): def get_input_embeddings(self): return self.layoutlm.embeddings.word_embeddings + @can_return_tuple @auto_docstring def forward( self, @@ -996,7 +882,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) pooled_output = outputs[1] @@ -1026,9 +912,6 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel): elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, @@ -1059,6 +942,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel): def get_input_embeddings(self): return self.layoutlm.embeddings.word_embeddings + @can_return_tuple @auto_docstring def forward( self, @@ -1132,7 +1016,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = outputs[0] @@ -1145,10 +1029,6 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel): loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -1176,6 +1056,7 @@ class 
LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel): def get_input_embeddings(self): return self.layoutlm.embeddings.word_embeddings + @can_return_tuple @auto_docstring def forward( self, @@ -1253,7 +1134,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = outputs[0] @@ -1280,10 +1161,6 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel): end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e79a7697602..74c9651aca0 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -224,7 +224,7 @@ class LlamaAttention(nn.Module): past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 53d2c7fc9b2..259482934f3 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -1358,27 +1358,28 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): vision_feature_select_strategy=vision_feature_select_strategy, image_sizes=image_sizes, ) - original_inputs_embeds_shape = inputs_embeds.shape vision_flat = image_features.view(-1, image_features.size(-1)) projected_vision_flat = self.multi_modal_projector(vision_flat) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - final_mask = special_image_mask.to(inputs_embeds.device) - inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id - final_mask_1d = final_mask[..., 0].reshape(-1) - num_tokens_to_fill = final_mask_1d.sum() + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - if num_tokens_to_fill != projected_vision_flat.size(0): + if n_image_tokens != projected_vision_flat.size(0): raise ValueError( - f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " + f"Mismatch: final_mask wants {n_image_tokens} embeddings, " f"but multi_modal_projector returned {projected_vision_flat.size(0)}" ) - - expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1)) - inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat) - inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) + projected_vision_flat = projected_vision_flat.to(inputs_embeds.device, 
inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index a7296538271..346dc98f2bf 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -284,14 +284,14 @@ class LlavaModel(LlavaPreTrainedModel): special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + special_image_mask = special_image_mask.all(-1) else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_tokens = (input_ids == self.config.image_token_id).sum() + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 08dfb36d47e..14ad299a73f 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -468,11 +468,6 @@ class LlavaNextModel(LlavaNextPreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -485,10 +480,18 @@ class LlavaNextModel(LlavaNextPreTrainedModel): ) image_features = torch.cat(image_features, dim=0) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py 
b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 5a01c350210..dbd6ceaed1f 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -519,12 +519,6 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: - raise ValueError( - "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " - "and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -537,10 +531,18 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): ) image_features = torch.cat(image_features, dim=0) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" @@ -559,10 +561,18 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.video_token_id + + n_video_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_features.shape[0] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index badaa23a71b..0a328593499 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -440,12 +440,6 @@ class LlavaNextVideoModel(LlavaNextModel): if (input_ids is 
None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: - raise ValueError( - "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " - "and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -458,10 +452,18 @@ class LlavaNextVideoModel(LlavaNextModel): ) image_features = torch.cat(image_features, dim=0) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" @@ -480,10 +482,18 @@ class LlavaNextVideoModel(LlavaNextModel): video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.video_token_id + + n_video_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_features.shape[0] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index a06a9750baa..616a314e3be 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -551,12 +551,6 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: - raise ValueError( - "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " - "and must specify either one" - ) - 
if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -571,10 +565,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): ) image_features = torch.cat(image_features, dim=0) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" @@ -595,10 +597,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): video_features = torch.cat((video_features, image_newline), dim=1) video_features = video_features.flatten(0, 1) - special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) - special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_video_mask = input_ids == self.config.video_token_id + + n_video_tokens = (special_video_mask).sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_id).sum() n_video_features = video_features.shape[0] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py index 4920124522b..2461c89b72a 100644 --- a/src/transformers/models/llava_onevision/modular_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py @@ -535,12 +535,6 @@ class LlavaOnevisionModel(LlavaNextVideoModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: - raise ValueError( - "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " - "and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -555,10 +549,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel): ) image_features = torch.cat(image_features, dim=0) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + 
special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" @@ -579,10 +581,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel): video_features = torch.cat((video_features, image_newline), dim=1) video_features = video_features.flatten(0, 1) - special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) - special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_video_mask = input_ids == self.config.video_token_id + + n_video_tokens = (special_video_mask).sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_id).sum() n_video_features = video_features.shape[0] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" diff --git a/src/transformers/models/markuplm/configuration_markuplm.py b/src/transformers/models/markuplm/configuration_markuplm.py index f8bee878e83..e5945cb3307 100644 --- a/src/transformers/models/markuplm/configuration_markuplm.py +++ b/src/transformers/models/markuplm/configuration_markuplm.py @@ -14,6 +14,8 @@ # limitations under the License. 
"""MarkupLM model configuration""" +import warnings + from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -141,7 +143,7 @@ class MarkupLMConfig(PretrainedConfig): self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type + self._position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout # additional properties @@ -152,5 +154,17 @@ class MarkupLMConfig(PretrainedConfig): self.subs_pad_id = subs_pad_id self.xpath_unit_hidden_size = xpath_unit_hidden_size + @property + def position_embedding_type(self): + warnings.warn( + "The `position_embedding_type` attribute is deprecated and will be removed in v4.55.", + FutureWarning, + ) + return self._position_embedding_type + + @position_embedding_type.setter + def position_embedding_type(self, value): + self._position_embedding_type = value + __all__ = ["MarkupLMConfig"] diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 4a34c85b3db..41dba3a2563 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -14,9 +14,8 @@ # limitations under the License. """PyTorch MarkupLM model.""" -import math import os -from typing import Optional, Union +from typing import Callable, Optional, Union import torch import torch.utils.checkpoint @@ -26,20 +25,22 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, + BaseModelOutput, + BaseModelOutputWithPooling, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, ) from ...modeling_utils import ( + ALL_ATTENTION_FUNCTIONS, PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer, ) -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_markuplm import MarkupLMConfig @@ -326,9 +327,37 @@ class MarkupLMOnlyMLMHead(nn.Module): return prediction_scores -# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MarkupLM +# Copied from transformers.models.align.modeling_align.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +# Copied from 
transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->MarkupLM class MarkupLMSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -336,6 +365,7 @@ class MarkupLMSelfAttention(nn.Module): f"heads ({config.num_attention_heads})" ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -345,20 +375,12 @@ class MarkupLMSelfAttention(nn.Module): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - - self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) + self.attention_dropout = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -368,111 +390,41 @@ class MarkupLMSelfAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. 
- is_cross_attention = encoder_hidden_states is not None + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_layer = self.transpose_for_scores(mixed_query_layer) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. 
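For reference, the refactor above routes MarkupLM self-attention through `eager_attention_forward` (or a kernel registered in `ALL_ATTENTION_FUNCTIONS`) instead of the removed relative-position code. A minimal, self-contained sketch of that eager computation on random tensors, illustrative only and not part of the diff (names simplified; it mirrors the `eager_attention_forward` added earlier in this file):

import torch
import torch.nn as nn

def eager_attention(query, key, value, scaling, attention_mask=None, head_mask=None, dropout=0.0, training=False):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

hidden = torch.randn(2, 7, 768)                           # (batch, seq, hidden)
q = k = v = hidden.view(2, 7, 12, 64).transpose(1, 2)     # same view/transpose reshape as in the new forward
out, weights = eager_attention(q, k, v, scaling=64 ** -0.5)
print(out.shape, weights.shape)                           # (2, 7, 12, 64) and (2, 12, 7, 7)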
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in MarkupLMModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs -MARKUPLM_SELF_ATTENTION_CLASSES = { - "eager": MarkupLMSelfAttention, -} - - -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM,BERT->MARKUPLM +# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->MarkupLM class MarkupLMAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = MARKUPLM_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type - ) + self.self = MarkupLMSelfAttention(config) self.output = MarkupLMSelfOutput(config) self.pruned_heads = set() @@ -494,6 +446,9 @@ class MarkupLMAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -503,37 +458,33 @@ class MarkupLMAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs -# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->MarkupLM +# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->MarkupLM class MarkupLMLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = MarkupLMAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = MarkupLMAttention(config, position_embedding_type="absolute") self.intermediate = MarkupLMIntermediate(config) self.output = MarkupLMOutput(config) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( 
self, hidden_states: torch.Tensor, @@ -543,60 +494,23 @@ class MarkupLMLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + **kwargs, ) attention_output = self_attention_outputs[0] - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -605,14 +519,19 @@ class MarkupLMLayer(GradientCheckpointingLayer): return layer_output -# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->MarkupLM +# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->MarkupLM class MarkupLMEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([MarkupLMLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([MarkupLMLayer(config) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple def forward( self, hidden_states: torch.Tensor, @@ -625,65 +544,36 @@ class MarkupLMEncoder(nn.Module): output_attentions: Optional[bool] = False, 
output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, - ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, # as a positional argument for gradient checkpointing - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, output_attentions=output_attentions, + **kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutput( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, - cross_attentions=all_cross_attentions, ) @@ -749,6 +639,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) + @can_return_tuple @auto_docstring def forward( self, @@ -763,7 +654,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]: + ) -> Union[tuple, BaseModelOutputWithPooling]: r""" xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*): Tag IDs for each token in the input sequence, padded up to config.max_depth. 
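As a usage note for the return-type change above: the forward now always builds a `ModelOutput` internally (submodules are called with `return_dict=True`), and `@can_return_tuple` converts it back to a plain tuple when the caller asks for one. A hedged sketch of that behavior, assuming the standard MarkupLM processor workflow and the public `microsoft/markuplm-base` checkpoint:

import torch
from transformers import MarkupLMModel, MarkupLMProcessor

processor = MarkupLMProcessor.from_pretrained("microsoft/markuplm-base")
model = MarkupLMModel.from_pretrained("microsoft/markuplm-base")

html = "<html><body><h1>Hello world</h1></body></html>"
encoding = processor(html, return_tensors="pt")

with torch.no_grad():
    outputs = model(**encoding)                    # BaseModelOutputWithPooling
    legacy = model(**encoding, return_dict=False)  # plain tuple, produced by @can_return_tuple

print(type(outputs).__name__, outputs.last_hidden_state.shape)
print(type(legacy), len(legacy))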
@@ -839,21 +730,16 @@ class MarkupLMModel(MarkupLMPreTrainedModel): head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( + return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, ) # Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache @@ -879,6 +765,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -939,7 +826,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = outputs[0] @@ -966,10 +853,6 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel): end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, @@ -1000,6 +883,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -1058,7 +942,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = outputs[0] @@ -1072,10 +956,6 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel): labels.view(-1), ) - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=prediction_scores, @@ -1107,6 +987,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -1164,7 +1045,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel): inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) pooled_output = outputs[1] @@ -1194,9 +1075,6 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel): elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 935ffcb67bf..6cfaf8d92e7 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ 
b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -354,11 +354,6 @@ class MaskFormerSwinSelfAttention(nn.Module): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -367,11 +362,11 @@ class MaskFormerSwinSelfAttention(nn.Module): output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape - mixed_query_layer = self.query(hidden_states) + hidden_shape = (batch_size, dim, -1, self.attention_head_size) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 9b885c3c385..63b8a0b0b25 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -308,11 +308,6 @@ class Mistral3Model(Mistral3PreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -324,10 +319,18 @@ class Mistral3Model(Mistral3PreTrainedModel): ) image_features = torch.cat(image_features, dim=0) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index 391b02d42b8..2027d323a51 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -204,11 +204,6 @@ class Mistral3Model(LlavaModel): if (input_ids is None) ^ (inputs_embeds is not 
None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -220,10 +215,18 @@ class Mistral3Model(LlavaModel): ) image_features = torch.cat(image_features, dim=0) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index bc217736551..11765cf3380 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -182,7 +182,6 @@ def eager_attention_forward( return attn_output, attn_weights -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Musicgen class MusicgenAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index b3a2322e4aa..5cdbcd7a696 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -189,7 +189,7 @@ def eager_attention_forward( return attn_output, attn_weights -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->MusicgenMelody +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->MusicgenMelody class MusicgenMelodyAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index f498cf743fc..48f184c078b 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -503,7 +503,7 @@ def eager_attention_forward( return attn_output, attn_weights -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->NllbMoe,key_value_states->encoder_hidden_states +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->NllbMoe,key_value_states->encoder_hidden_states class NllbMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git 
a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 6aabb3a3d80..36183c66b66 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -331,9 +331,11 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): special_image_mask = inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) + special_image_mask = special_image_mask.all(-1) else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 857e4eb320d..1ff46faf6ef 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -29,6 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_patchtsmixer import PatchTSMixerConfig @@ -303,6 +304,7 @@ class PatchTSMixerAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -310,7 +312,7 @@ class PatchTSMixerAttention(nn.Module): past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -331,42 +333,9 @@ class PatchTSMixerAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = 
self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -388,7 +357,7 @@ class PatchTSMixerAttention(nn.Module): attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class PatchMixerBlock(nn.Module): diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index ec8349dfd6f..dfd28ea2b0a 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -28,6 +28,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_patchtst import PatchTSTConfig @@ -100,6 +101,7 @@ class PatchTSTAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -107,7 +109,7 @@ class PatchTSTAttention(nn.Module): past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -128,42 +130,9 @@ class PatchTSTAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and 
past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -185,7 +154,7 @@ class PatchTSTAttention(nn.Module): attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class PatchTSTBatchNorm(nn.Module): diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index c63beb73fac..37a290e8001 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -607,6 +607,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): f" and `num_heads`: {self.num_heads})." 
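The PatchTSMixer and PatchTST rewrites above (and the SEW one later in this diff) all collapse the legacy `past_key_value` branching into a single projection of either the cross-attention states or the hidden states, since these encoders never cache key/value pairs and now always return `None` in that slot. A minimal sketch of the pattern they converge on, with toy dimensions (illustrative module, not the actual transformers class):

```python
import torch
import torch.nn as nn

class SimplifiedAttentionProjections(nn.Module):
    # Illustrative only: mirrors the pattern
    #   current_states = key_value_states if is_cross_attention else hidden_states
    # used by the rewritten attention modules; no KV cache is kept.
    def __init__(self, embed_dim: int = 16, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states, key_value_states=None):
        bsz, tgt_len, _ = hidden_states.shape
        is_cross_attention = key_value_states is not None
        current_states = key_value_states if is_cross_attention else hidden_states
        src_len = current_states.shape[1]

        query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(current_states).view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(current_states).view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        # The rewritten modules return `None` where the old `past_key_value` tuple used to go.
        return query_states, key_states, value_states, None

attn = SimplifiedAttentionProjections()
q, k, v, past = attn(torch.randn(2, 7, 16))                          # self-attention
q2, k2, v2, _ = attn(torch.randn(2, 7, 16), torch.randn(2, 9, 16))   # cross-attention
print(q.shape, k.shape, k2.shape, past)  # [2, 4, 7, 4] ... [2, 4, 9, 4] None
```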
) self.scaling = self.head_dim**-0.5 + self.attention_dropout = 0.0 self.is_decoder = False self.is_causal = False @@ -619,6 +620,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -634,15 +636,6 @@ class Qwen2_5OmniAudioAttention(nn.Module): value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attention_mask = torch.full( - [1, 1, seq_length, key_states.shape[-2]], - torch.finfo(query_states.dtype).min, - device=query_states.device, - dtype=query_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -652,13 +645,13 @@ class Qwen2_5OmniAudioAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -686,6 +679,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -704,6 +698,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer): hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, + attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states @@ -785,6 +780,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): def set_input_embeddings(self, value: nn.Module): self.conv1 = value + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
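The `_prepare_attention_mask` helper introduced here centralizes the block-diagonal mask that the attention modules used to build inline from `cu_seqlens`, and returns `None` on the flash-attention-2 path, which consumes `cu_seqlens` directly. A standalone sketch of the same construction with toy boundaries (illustrative function name and defaults, not the actual method):

```python
import torch

def block_diagonal_bias(cu_seqlens: torch.Tensor, dtype=torch.float32, attn_implementation: str = "sdpa"):
    """Additive [1, 1, L, L] bias that only lets positions attend within their own
    `cu_seqlens` block; flash-attention-2 gets no mask and uses cu_seqlens instead."""
    if attn_implementation == "flash_attention_2":
        return None
    seq_length = int(cu_seqlens[-1])
    mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
    return mask

cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)        # two packed sequences of length 3 and 2
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()   # 3, computed the same way in the attention modules
bias = block_diagonal_bias(cu_seqlens)
print(max_seqlen)
print((bias[0, 0] == 0).int())  # block-diagonal pattern: a 3x3 and a 2x2 block of ones
```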
Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + @auto_docstring def forward( self, @@ -833,9 +847,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): padded_mask_after_cnn.sum(1).cumsum(0), ) ).to(torch.int32) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) for encoder_layer in self.layers: - layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs) + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + **kwargs, + ) hidden_states = layer_outputs[0] hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) @@ -928,12 +948,15 @@ class Qwen2_5OmniVisionAttention(nn.Module): self.scaling = self.head_dim**-0.5 self.num_key_value_groups = 1 # needed for eager attention self.config = config + self.attention_dropout = 0.0 + self.is_causal = False def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -943,18 +966,9 @@ class Qwen2_5OmniVisionAttention(nn.Module): query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(query_states.dtype).min, - device=query_states.device, - dtype=query_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward @@ -966,13 +980,13 @@ class Qwen2_5OmniVisionAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -1009,10 +1023,15 @@ class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer): hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), 
cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, + **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -1171,6 +1190,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): return window_index, cu_window_seqlens + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -1217,10 +1255,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens + + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, **kwargs, ) hidden_states = self.merger(hidden_states) @@ -1862,43 +1903,51 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo inputs_embeds = self.get_input_embeddings()(input_ids) # 2. 
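In the vision encoders the mask now has to be recomputed per block, because some blocks attend over the full sequence while the rest only attend within windows described by `cu_window_seqlens`. A rough sketch of that per-block dispatch, assuming a `fullatt_block_indexes`-style list as in the Qwen2.5-VL config (toy boundaries, not the real encoder loop):

```python
import torch

def block_bias(cu):  # same block-diagonal construction as sketched above
    L = int(cu[-1])
    m = torch.full([1, 1, L, L], torch.finfo(torch.float32).min)
    for i in range(1, len(cu)):
        m[..., cu[i - 1] : cu[i], cu[i - 1] : cu[i]] = 0
    return m

cu_seqlens = torch.tensor([0, 6], dtype=torch.int32)                # one full 6-token sequence
cu_window_seqlens = torch.tensor([0, 2, 4, 6], dtype=torch.int32)   # three 2-token windows
fullatt_block_indexes = [1]  # assumption: block 1 attends over the full sequence, others are windowed

for layer_num in range(3):
    cu_seqlens_now = cu_seqlens if layer_num in fullatt_block_indexes else cu_window_seqlens
    attention_mask = block_bias(cu_seqlens_now)
    # a real block would now run attention with (cu_seqlens_now, attention_mask)
    print(layer_num, (attention_mask[0, 0] == 0).sum().item())  # 12, 36, 12 unmasked entries
```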
Merge text , audios , image and video - if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage - if input_features is not None: - audio_features = self.get_audio_features( - input_features, - feature_attention_mask=feature_attention_mask, - audio_feature_lengths=audio_feature_lengths, + if input_features is not None: + audio_features = self.get_audio_features( + input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + ) + if input_ids is None: + audio_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) ) - audio_mask = ( - (input_ids == self.config.audio_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) + audio_mask = audio_mask.all(-1) + else: + audio_mask = input_ids == self.config.audio_token_id - if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) - if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + image_mask = image_mask.all(-1) + else: + image_mask = input_ids == self.config.image_token_id + + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + if input_ids is None: + video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + video_mask = video_mask.all(-1) + else: + video_mask = input_ids == self.config.video_token_id + + video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py 
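The merging logic above no longer gates on a prefill-shaped `input_ids`: the placeholder mask is taken from `input_ids` when it is available, or recovered from `inputs_embeds` by comparing each row against the placeholder token's embedding and reducing with `.all(-1)`, before `masked_scatter` writes the modality features in. A toy sketch of that mechanic (hypothetical vocabulary, token id and sizes):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, hidden, image_token_id = 10, 4, 7
embed = nn.Embedding(vocab_size, hidden)

input_ids = torch.tensor([[1, 7, 7, 2]])        # two image placeholders
inputs_embeds = embed(input_ids)                # what the model sees when ids are dropped
image_features = torch.randn(2, hidden)         # one feature row per placeholder token

# Path 1: ids are available
mask_from_ids = input_ids == image_token_id

# Path 2: only embeddings are available -> compare rows against the placeholder embedding
placeholder = embed(torch.tensor(image_token_id))
mask_from_embeds = (inputs_embeds == placeholder).all(-1)

assert torch.equal(mask_from_ids, mask_from_embeds)

expanded = mask_from_ids.unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(expanded, image_features.to(inputs_embeds.dtype))
print(merged.shape)  # torch.Size([1, 4, 4]); positions 1 and 2 now hold the image features
```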
b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 9acc76c9afa..645a5fb837f 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1611,6 +1611,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): f" and `num_heads`: {self.num_heads})." ) self.scaling = self.head_dim**-0.5 + self.attention_dropout = 0.0 self.is_decoder = False self.is_causal = False @@ -1623,6 +1624,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -1638,15 +1640,6 @@ class Qwen2_5OmniAudioAttention(nn.Module): value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attention_mask = torch.full( - [1, 1, seq_length, key_states.shape[-2]], - torch.finfo(query_states.dtype).min, - device=query_states.device, - dtype=query_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -1656,13 +1649,13 @@ class Qwen2_5OmniAudioAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -1682,6 +1675,7 @@ class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -1689,6 +1683,7 @@ class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer): hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, + attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states @@ -1770,6 +1765,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): def set_input_embeddings(self, value: nn.Module): self.conv1 = value + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + @auto_docstring def forward( self, @@ -1818,9 +1832,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): padded_mask_after_cnn.sum(1).cumsum(0), ) ).to(torch.int32) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) for encoder_layer in self.layers: - layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs) + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + **kwargs, + ) hidden_states = layer_outputs[0] hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) @@ -1906,12 +1926,15 @@ class Qwen2_5OmniVisionAttention(nn.Module): self.scaling = self.head_dim**-0.5 self.num_key_value_groups = 1 # needed for eager attention self.config = config + self.attention_dropout = 0.0 + self.is_causal = False def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -1921,18 +1944,9 @@ class Qwen2_5OmniVisionAttention(nn.Module): query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(query_states.dtype).min, - device=query_states.device, - dtype=query_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward @@ -1944,13 +1958,13 @@ class Qwen2_5OmniVisionAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -1970,10 +1984,15 @@ class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock): hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), 
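The audio encoder drives the helper with `cu_seqlens` obtained by cumulatively summing the per-sample frame counts that survive the CNN front-end (`padded_mask_after_cnn.sum(1).cumsum(0)`) and prepending a zero. A small sketch of that derivation with made-up lengths:

```python
import torch
import torch.nn.functional as F

# pretend three audio clips survive the CNN with 4, 2 and 3 frames respectively
padded_mask_after_cnn = torch.tensor(
    [[1, 1, 1, 1],
     [1, 1, 0, 0],
     [1, 1, 1, 0]], dtype=torch.bool
)
lengths = padded_mask_after_cnn.sum(1)                            # tensor([4, 2, 3])
cu_seqlens = F.pad(lengths.cumsum(0), (1, 0), value=0).to(torch.int32)
print(cu_seqlens)                                                 # tensor([0, 4, 6, 9], dtype=torch.int32)
# cu_seqlens then feeds both the FA2 varlen path and the block-diagonal mask sketched earlier
```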
cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, + **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -1987,6 +2006,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): super().__init__(config, *inputs, **kwargs) self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)]) + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -2033,10 +2071,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens + + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, **kwargs, ) hidden_states = self.merger(hidden_states) @@ -2309,43 +2350,51 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo inputs_embeds = self.get_input_embeddings()(input_ids) # 2. 
Merge text , audios , image and video - if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage - if input_features is not None: - audio_features = self.get_audio_features( - input_features, - feature_attention_mask=feature_attention_mask, - audio_feature_lengths=audio_feature_lengths, + if input_features is not None: + audio_features = self.get_audio_features( + input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + ) + if input_ids is None: + audio_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) ) - audio_mask = ( - (input_ids == self.config.audio_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) + audio_mask = audio_mask.all(-1) + else: + audio_mask = input_ids == self.config.audio_token_id - if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) - if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + image_mask = image_mask.all(-1) + else: + image_mask = input_ids == self.config.image_token_id + + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + if input_ids is None: + video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + video_mask = video_mask.all(-1) + else: + video_mask = input_ids == self.config.video_token_id + + video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py 
b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index ab318d955ff..1b69973cc3b 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -206,6 +206,8 @@ class Qwen2_5_VLVisionAttention(nn.Module): self.proj = nn.Linear(self.dim, self.dim) self.scaling = self.head_dim**-0.5 self.config = config + self.attention_dropout = 0.0 + self.is_causal = False def forward( self, @@ -213,6 +215,7 @@ class Qwen2_5_VLVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -233,18 +236,9 @@ class Qwen2_5_VLVisionAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(value_states.dtype).min, - device=value_states.device, - dtype=value_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward @@ -256,13 +250,13 @@ class Qwen2_5_VLVisionAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -286,6 +280,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -293,6 +288,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, + attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -426,6 +422,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): return window_index, cu_window_seqlens + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -472,8 +487,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens + + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( - hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs + hidden_states, + cu_seqlens=cu_seqlens_now, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, ) hidden_states = self.merger(hidden_states) @@ -1224,41 +1245,51 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = torch.cat(image_embeds, dim=0) - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_embeds.shape[0] - if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + image_mask = image_mask.all(-1) + else: + image_mask = input_ids == self.config.image_token_id - if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_embeds = torch.cat(video_embeds, dim=0) - n_video_tokens = (input_ids == self.config.video_token_id).sum() - n_video_features = video_embeds.shape[0] - if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) + n_image_tokens = (image_mask).sum() + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_embeds.shape[0] + if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - 
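Before scattering, the rewritten Qwen2.5-VL merging code counts placeholder tokens from whichever mask it derived and checks the count against the number of feature rows (the check is skipped under torchdynamo compilation). A toy sketch of that guard, with illustrative names:

```python
import torch

def scatter_with_check(inputs_embeds, token_mask, features, what="image"):
    n_tokens = int(token_mask.sum())
    n_features = features.shape[0]
    if n_tokens != n_features:
        raise ValueError(f"{what} features and {what} tokens do not match: tokens: {n_tokens}, features {n_features}")
    expanded = token_mask.unsqueeze(-1).expand_as(inputs_embeds)
    return inputs_embeds.masked_scatter(expanded, features.to(inputs_embeds.dtype))

embeds = torch.zeros(1, 5, 3)
mask = torch.tensor([[False, True, True, False, False]])
print(scatter_with_check(embeds, mask, torch.ones(2, 3))[0])       # positions 1-2 filled with ones
# scatter_with_check(embeds, mask, torch.ones(3, 3))  -> ValueError (3 features vs 2 tokens)
```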
mask = input_ids == self.config.video_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - video_mask = mask_expanded.to(inputs_embeds.device) + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + if input_ids is None: + video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + video_mask = video_mask.all(-1) + else: + video_mask = input_ids == self.config.video_token_id + + n_video_tokens = (video_mask).sum() + n_video_features = video_embeds.shape[0] + video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if position_ids is None: attention_mask_tensor = ( @@ -1565,6 +1596,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi def _get_image_nums_and_video_nums( self, input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Get the number of images and videos for each sample to calculate the separation length of the sample tensor. @@ -1582,10 +1614,31 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id - vision_start_mask = input_ids == vision_start_token_id + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) - image_mask = input_ids == image_token_id - video_mask = input_ids == video_token_id image_nums = torch.sum(vision_first_mask & image_mask, dim=1) video_nums = torch.sum(vision_first_mask & video_mask, dim=1) @@ -1611,7 +1664,9 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi def _expand_dict_for_generation_visual(dict_to_expand): image_grid_thw = model_kwargs.get("image_grid_thw", None) video_grid_thw = model_kwargs.get("video_grid_thw", None) - image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) def _repeat_interleave_samples(x, lengths, repeat_times): samples = torch.split(x, lengths) 
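`_get_image_nums_and_video_nums` counts image and video segments per sample by looking at the token immediately following each `vision_start` token; the change above lets it build the same masks from `inputs_embeds` alone. A toy sketch of the id-based counting logic with made-up token ids:

```python
import torch

vision_start_token_id, image_token_id, video_token_id = 90, 91, 92
input_ids = torch.tensor([
    [1, 90, 91, 91, 5, 90, 92, 92, 2],   # one image segment, one video segment
    [1, 90, 91, 91, 5, 90, 91, 91, 2],   # two image segments
])

vision_start_mask = input_ids == vision_start_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)  # positions right after a start token
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id

image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
print(image_nums, video_nums)  # tensor([1, 2]) tensor([1, 0])
```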
@@ -1667,10 +1722,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) return dict_to_expand - # input_ids is required for expanding visual inputs - # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. - if input_ids is not None and input_ids.numel() != 0: - model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) if input_ids is not None: input_ids = input_ids.repeat_interleave(expand_size, dim=0) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 84a7a69ac81..2686194e09e 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -159,6 +159,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -166,6 +167,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, + attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -287,6 +289,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): return window_index, cu_window_seqlens + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
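`_expand_dict_for_generation_visual` is now invoked unconditionally; it splits the flattened visual tensors per sample (using the counts above) and repeats each sample's slice `expand_size` times so that every return sequence gets its own copy. A rough sketch of the split-then-repeat idea, with an assumed helper name:

```python
import torch

def repeat_samples(x, lengths, repeat_times):
    # split a flattened per-sample tensor into samples, tile each sample's block, re-concatenate
    samples = torch.split(x, lengths)
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)

# e.g. image_grid_thw rows: 2 images belong to sample 0, 1 image to sample 1
image_grid_thw = torch.tensor([[1, 2, 2], [1, 4, 4], [1, 6, 6]])
expanded = repeat_samples(image_grid_thw, lengths=[2, 1], repeat_times=2)
print(expanded.shape)  # torch.Size([6, 3]): each sample's rows duplicated for 2 return sequences
```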
Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -333,8 +354,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens + + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( - hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs + hidden_states, + cu_seqlens=cu_seqlens_now, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, ) hidden_states = self.merger(hidden_states) @@ -582,41 +609,51 @@ class Qwen2_5_VLModel(Qwen2VLModel): if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = torch.cat(image_embeds, dim=0) - n_image_tokens = (input_ids == self.config.image_token_id).sum() - n_image_features = image_embeds.shape[0] - if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + image_mask = image_mask.all(-1) + else: + image_mask = input_ids == self.config.image_token_id - if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_embeds = torch.cat(video_embeds, dim=0) - n_video_tokens = (input_ids == self.config.video_token_id).sum() - n_video_features = video_embeds.shape[0] - if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) + n_image_tokens = (image_mask).sum() + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_embeds.shape[0] + if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - mask = 
input_ids == self.config.video_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - video_mask = mask_expanded.to(inputs_embeds.device) + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + if input_ids is None: + video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + video_mask = video_mask.all(-1) + else: + video_mask = input_ids == self.config.video_token_id + + n_video_tokens = (video_mask).sum() + n_video_features = video_embeds.shape[0] + video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if position_ids is None: attention_mask_tensor = ( diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index a799e7328e5..311622b2222 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -324,6 +324,8 @@ class VisionAttention(nn.Module): self.proj = nn.Linear(self.dim, self.dim) self.scaling = self.head_dim**-0.5 self.config = config + self.attention_dropout = 0.0 + self.is_causal = False def forward( self, @@ -331,6 +333,7 @@ class VisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -351,18 +354,9 @@ class VisionAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(value_states.dtype).min, - device=value_states.device, - dtype=value_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward @@ -374,13 +368,13 @@ class VisionAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - 
max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -406,6 +400,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -413,6 +408,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, + attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -725,6 +721,25 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + @auto_docstring def forward( self, @@ -750,10 +765,15 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) for blk in self.blocks: hidden_states = blk( - hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, **kwargs + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, ) return self.merger(hidden_states) @@ -1162,41 +1182,52 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = torch.cat(image_embeds, dim=0) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_embeds = torch.cat(video_embeds, dim=0) - 
n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0) + + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + image_mask = image_mask.all(-1) + else: + image_mask = input_ids == self.config.image_token_id + + n_image_tokens = image_mask.sum() + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_embeds.shape[0] + if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0) + + if input_ids is None: + video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + n_video_tokens = (video_mask).sum(dim=1).sum(dim=0)[0] + else: + video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) + video_mask = video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + n_video_tokens = (input_ids == self.config.video_token_id).sum() + + n_video_features = video_embeds.shape[0] + if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if position_ids is None: attention_mask_tensor = ( @@ -1460,6 +1491,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): def _get_image_nums_and_video_nums( self, input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@@ -1477,10 +1509,31 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id - vision_start_mask = input_ids == vision_start_token_id + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) - image_mask = input_ids == image_token_id - video_mask = input_ids == video_token_id image_nums = torch.sum(vision_first_mask & image_mask, dim=1) video_nums = torch.sum(vision_first_mask & video_mask, dim=1) @@ -1506,7 +1559,9 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): def _expand_dict_for_generation_visual(dict_to_expand): image_grid_thw = model_kwargs.get("image_grid_thw", None) video_grid_thw = model_kwargs.get("video_grid_thw", None) - image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) def _repeat_interleave_samples(x, lengths, repeat_times): samples = torch.split(x, lengths) @@ -1562,10 +1617,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) return dict_to_expand - # input_ids is required for expanding visual inputs - # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. 
- if input_ids is not None and input_ids.numel() != 0: - model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) if input_ids is not None: input_ids = input_ids.repeat_interleave(expand_size, dim=0) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index da4a54b39fc..5ca359f0b02 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_sew import SEWConfig @@ -293,6 +294,7 @@ class SEWAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -300,7 +302,7 @@ class SEWAttention(nn.Module): past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -321,42 +323,9 @@ class SEWAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -378,7 +347,7 @@ class SEWAttention(nn.Module): attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class SEWFeedForward(nn.Module): diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 1b128a0fb63..155f8110a54 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -595,7 +595,14 @@ class SmolVLMModel(SmolVLMPreTrainedModel): """ _, patch_size, _ = image_hidden_states.shape - image_mask = input_ids == self.image_token_id + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + image_mask = image_mask[..., 0] # slice off the hidden dim + else: + image_mask = input_ids == self.config.image_token_id + num_image_tokens = image_mask.sum(dim=1) if not torch.all(num_image_tokens % patch_size == 0): raise ValueError("At least one sample has tokens not divisible by patch_size.") @@ -717,14 +724,8 @@ class SmolVLMModel(SmolVLMPreTrainedModel): else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_seen_tokens = 0 - if use_cache: - if past_key_values is None: - past_key_values = DynamicCache() - past_seen_tokens = past_key_values.get_seq_length() - - if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: - raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) @@ -732,12 +733,13 @@ class SmolVLMModel(SmolVLMPreTrainedModel): # START VISUAL INPUTS INTEGRATION if pixel_values is not None and image_hidden_states is not None: raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") - elif pixel_values is not None: - image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(input_ids.device) - elif image_hidden_states is not None: - image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) - if inputs_embeds is not None and image_hidden_states is not None: + if pixel_values is not None: + image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(inputs_embeds.device) + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device) + + if image_hidden_states is not None: # When we generate, we don't want to replace the potential image_token_id that we generated by images # that simply don't exist 
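Side note from review: the SEW attention hunk above, and the matching UniSpeech, UniSpeechSat and Wav2Vec2 hunks later in this patch, collapse the old key/value branching into a single choice between the encoder states (cross-attention) and the current hidden states (self-attention), dropping the tuple-based cache handling. A condensed toy illustration of the resulting pattern (my own module, not the library code):

import torch
from torch import nn

class ToyKVProjection(nn.Module):
    """Toy stand-in for the simplified k/v projection used in these attention hunks."""

    def __init__(self, dim: int):
        super().__init__()
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, key_value_states=None):
        # cross-attention projects from the encoder states, self-attention from the
        # current hidden states; no past_key_value concatenation is performed anymore
        is_cross_attention = key_value_states is not None
        current_states = key_value_states if is_cross_attention else hidden_states
        return self.k_proj(current_states), self.v_proj(current_states)

kv = ToyKVProjection(8)
k_self, v_self = kv(torch.randn(1, 4, 8))                          # self-attention
k_cross, v_cross = kv(torch.randn(1, 4, 8), torch.randn(1, 6, 8))  # cross-attention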
inputs_embeds = self.inputs_merger( @@ -996,27 +998,11 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin): **kwargs, ) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - # but IDEFICS requires both ids and embeds to be present - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs["input_ids"] = input_ids - - if image_hidden_states is not None: + if image_hidden_states is not None or cache_position[0] != 0: model_inputs["pixel_values"] = None model_inputs["pixel_attention_mask"] = None return model_inputs - def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): - model_kwargs = super()._update_model_kwargs_for_generation( - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - **kwargs, - ) - # Get the precomputed image_hidden_states - model_kwargs["image_hidden_states"] = outputs.image_hidden_states - return model_kwargs - __all__ = ["SmolVLMForConditionalGeneration", "SmolVLMPreTrainedModel", "SmolVLMModel", "SmolVLMVisionTransformer"] diff --git a/src/transformers/models/smolvlm/modular_smolvlm.py b/src/transformers/models/smolvlm/modular_smolvlm.py index d4fffa0c40f..e20f3d7b6bf 100644 --- a/src/transformers/models/smolvlm/modular_smolvlm.py +++ b/src/transformers/models/smolvlm/modular_smolvlm.py @@ -180,7 +180,14 @@ class SmolVLMModel(Idefics3Model): ): _, patch_size, _ = image_hidden_states.shape - image_mask = input_ids == self.image_token_id + if input_ids is None: + image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + image_mask = image_mask[..., 0] # slice off the hidden dim + else: + image_mask = input_ids == self.config.image_token_id + num_image_tokens = image_mask.sum(dim=1) if not torch.all(num_image_tokens % patch_size == 0): raise ValueError("At least one sample has tokens not divisible by patch_size.") @@ -296,14 +303,8 @@ class SmolVLMModel(Idefics3Model): else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_seen_tokens = 0 - if use_cache: - if past_key_values is None: - past_key_values = DynamicCache() - past_seen_tokens = past_key_values.get_seq_length() - - if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: - raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) @@ -311,12 +312,13 @@ class SmolVLMModel(Idefics3Model): # START VISUAL INPUTS INTEGRATION if pixel_values is not None and image_hidden_states is not None: raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") - elif pixel_values is not None: - image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(input_ids.device) - elif image_hidden_states is not None: - image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) - if inputs_embeds is not None and image_hidden_states is not None: + if pixel_values is not None: + image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(inputs_embeds.device) + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, 
device=inputs_embeds.device) + + if image_hidden_states is not None: # When we generate, we don't want to replace the potential image_token_id that we generated by images # that simply don't exist inputs_embeds = self.inputs_merger( diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index aaff8d90fec..73e3df2b4a9 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -205,7 +205,7 @@ def eager_attention_forward( return attn_output, attn_weights -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Speech2Text +# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->Speech2Text class Speech2TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 259272f445e..3b4e8f56002 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -14,9 +14,8 @@ # limitations under the License. """PyTorch Splinter model.""" -import math from dataclasses import dataclass -from typing import Optional, Union +from typing import Callable, Optional, Union import torch import torch.utils.checkpoint @@ -25,13 +24,19 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_outputs import ( + BaseModelOutput, + ModelOutput, + QuestionAnsweringModelOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( auto_docstring, + can_return_tuple, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_splinter import SplinterConfig @@ -64,7 +69,6 @@ class SplinterEmbeddings(nn.Module): token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values_length: Optional[int] = 0, ) -> tuple: if input_ids is not None: input_shape = input_ids.size() @@ -74,7 +78,7 @@ class SplinterEmbeddings(nn.Module): seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + position_ids = self.position_ids[:, :seq_length] if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) @@ -92,9 +96,37 @@ class SplinterEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter +# Copied from transformers.models.align.modeling_align.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if 
attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->Splinter class SplinterSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -102,6 +134,7 @@ class SplinterSelfAttention(nn.Module): f"heads ({config.num_attention_heads})" ) + self.config = config self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -111,20 +144,12 @@ class SplinterSelfAttention(nn.Module): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - - self.is_decoder = config.is_decoder - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) + self.attention_dropout = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -134,96 +159,33 @@ class SplinterSelfAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. 
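Editor's note: the `eager_attention_forward` helper introduced just above applies the attention mask additively before the softmax. A toy demonstration (shapes assumed, not Splinter code) of why a large negative or `-inf` bias removes attention to masked keys:

import torch
import torch.nn.functional as F

scores = torch.zeros(1, 1, 1, 4)  # (batch, heads, query_len, key_len), all equal
additive_mask = torch.tensor([[[[0.0, 0.0, float("-inf"), float("-inf")]]]])

probs = F.softmax(scores + additive_mask, dim=-1)
print(probs)  # ~[0.5, 0.5, 0.0, 0.0]: the two masked keys receive zero weight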
- is_cross_attention = encoder_hidden_states is not None + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_layer = self.transpose_for_scores(mixed_query_layer) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. 
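A quick editorial check on the rewritten forward above (the Swin change later in this patch does the same): the inline `view(...).transpose(1, 2)` produces exactly the layout the old `transpose_for_scores` helper built with `view` plus `permute(0, 2, 1, 3)`. Toy sizes, my own variable names:

import torch

batch, seq, num_heads, head_dim = 2, 5, 4, 8
x = torch.randn(batch, seq, num_heads * head_dim)

old = x.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)  # old helper
new = x.view(batch, seq, -1, head_dim).transpose(1, 2)             # new inline pattern

assert torch.equal(old, new) and new.shape == (batch, num_heads, seq, head_dim)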
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in SplinterModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs @@ -242,18 +204,11 @@ class SplinterSelfOutput(nn.Module): return hidden_states -SPLINTER_SELF_ATTENTION_CLASSES = { - "eager": SplinterSelfAttention, -} - - -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER +# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->Splinter class SplinterAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type - ) + self.self = SplinterSelfAttention(config) self.output = SplinterSelfOutput(config) self.pruned_heads = set() @@ -275,6 +230,9 @@ class SplinterAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -284,15 +242,14 @@ class SplinterAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -330,22 +287,19 @@ class SplinterOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter +# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->Splinter class SplinterLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = SplinterAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if not self.is_decoder: - raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = SplinterAttention(config, position_embedding_type="absolute") self.intermediate = SplinterIntermediate(config) self.output = SplinterOutput(config) + @deprecate_kwarg("encoder_hidden_states", 
version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -355,60 +309,23 @@ class SplinterLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, + **kwargs, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + **kwargs, ) attention_output = self_attention_outputs[0] - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None - if self.is_decoder and encoder_hidden_states is not None: - if not hasattr(self, "crossattention"): - raise ValueError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) - - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value - + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -417,14 +334,19 @@ class SplinterLayer(GradientCheckpointingLayer): return layer_output -# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Splinter +# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->Splinter class SplinterEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([SplinterLayer(config) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + 
@can_return_tuple def forward( self, hidden_states: torch.Tensor, @@ -437,65 +359,36 @@ class SplinterEncoder(nn.Module): output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, - ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, # as a positional argument for gradient checkpointing - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, output_attentions=output_attentions, + **kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutput( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, - cross_attentions=all_cross_attentions, ) @@ -554,6 +447,11 @@ class SplinterModel(SplinterPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) + @deprecate_kwarg("encoder_hidden_states", version="4.54.0") + @deprecate_kwarg("encoder_attention_mask", version="4.54.0") + @deprecate_kwarg("past_key_values", version="4.54.0") + @deprecate_kwarg("use_cache", version="4.54.0") + @can_return_tuple @auto_docstring def forward( self, @@ -570,7 +468,7 @@ class SplinterModel(SplinterPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + ) -> Union[tuple, BaseModelOutput]: r""" token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*): Segment token indices to indicate first and second portions of the inputs. 
Indices are selected in `[0, @@ -592,11 +490,6 @@ class SplinterModel(SplinterPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -610,11 +503,8 @@ class SplinterModel(SplinterPreTrainedModel): batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + attention_mask = torch.ones(((batch_size, seq_length)), device=device) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) @@ -622,17 +512,6 @@ class SplinterModel(SplinterPreTrainedModel): # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -645,31 +524,21 @@ class SplinterModel(SplinterPreTrainedModel): position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + return_dict=True, ) sequence_output = encoder_outputs[0] - if not return_dict: - return (sequence_output,) + encoder_outputs[1:] - - return BaseModelOutputWithPastAndCrossAttentions( + return BaseModelOutput( last_hidden_state=sequence_output, - past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, ) diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index 7bcf8d98251..33e50de7aa8 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -725,8 +725,8 @@ class SuperGlueForKeypointMatching(SuperGluePreTrainedModel): matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) 
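Editor's note: rather than removing the now-unused keyword arguments outright, the Splinter hunks above (and the SEW/UniSpeech/Wav2Vec2 ones) mark them with `deprecate_kwarg`. A small usage sketch, assuming the decorator's default behavior of warning when the named kwarg is still passed:

from transformers.utils.deprecation import deprecate_kwarg

@deprecate_kwarg("past_key_value", version="4.54.0")
def toy_forward(hidden_states, past_key_value=None):
    # the argument is still accepted for now, but callers get a deprecation warning
    return hidden_states

toy_forward("hidden", past_key_value=("k", "v"))  # warns until the stated version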
matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) - matches = torch.cat([matches0, matches1]).reshape(batch_size, 2, -1) - matching_scores = torch.cat([matching_scores0, matching_scores1]).reshape(batch_size, 2, -1) + matches = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1) + matching_scores = torch.cat([matching_scores0, matching_scores1], dim=1).reshape(batch_size, 2, -1) if output_hidden_states: all_hidden_states = all_hidden_states + encoded_keypoints[1] diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 7ea56890b58..5bd79aec335 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -435,11 +435,6 @@ class SwinSelfAttention(nn.Module): self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - def forward( self, hidden_states: torch.Tensor, @@ -448,11 +443,11 @@ class SwinSelfAttention(nn.Module): output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: batch_size, dim, num_channels = hidden_states.shape - mixed_query_layer = self.query(hidden_states) + hidden_shape = (batch_size, dim, -1, self.attention_head_size) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) # Take the dot product between "query" and "key" to get the raw attention scores. 
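Editor's note on the SuperGlue fix above: without `dim=1`, `torch.cat` stacks the two match directions along the batch axis, so the following `reshape(batch_size, 2, -1)` silently pairs keypoints from different samples whenever `batch_size > 1`. A toy repro (shapes assumed):

import torch

batch_size, num_keypoints = 2, 3
matches0 = torch.arange(batch_size * num_keypoints).reshape(batch_size, num_keypoints)        # image0 -> image1
matches1 = 100 + torch.arange(batch_size * num_keypoints).reshape(batch_size, num_keypoints)  # image1 -> image0

buggy = torch.cat([matches0, matches1]).reshape(batch_size, 2, -1)         # old behavior, dim=0
fixed = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1)  # patched behavior

print(buggy[0])  # [[0, 1, 2], [3, 4, 5]]: rows come from two different samples
print(fixed[0])  # [[0, 1, 2], [100, 101, 102]]: both directions belong to sample 0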
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 373b25b4e1a..e8a43c28261 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -45,6 +45,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_unispeech import UniSpeechConfig @@ -332,6 +333,7 @@ class UniSpeechAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -339,7 +341,7 @@ class UniSpeechAttention(nn.Module): past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -360,42 +362,9 @@ class UniSpeechAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -417,7 +386,7 @@ class UniSpeechAttention(nn.Module): attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class UniSpeechFeedForward(nn.Module): diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 0ce8a7c8154..0e2140aee85 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -47,6 +47,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_unispeech_sat import UniSpeechSatConfig @@ -337,6 +338,7 @@ class UniSpeechSatAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -344,7 +346,7 @@ class UniSpeechSatAttention(nn.Module): past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -365,42 +367,9 @@ class UniSpeechSatAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], 
value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -422,7 +391,7 @@ class UniSpeechSatAttention(nn.Module): attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class UniSpeechSatFeedForward(nn.Module): diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 7402bfbacd7..d5171e8831e 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -328,12 +328,6 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None: - raise ValueError( - "You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same " - "time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -343,10 +337,18 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features 
{n_image_features}" @@ -359,10 +361,18 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer ) - special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.video_token_id + + n_video_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_id).sum() n_video_features = video_features.shape[0] * video_features.shape[1] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 88d87a31bca..7e668bf978d 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -233,11 +233,6 @@ class VipLlavaModel(VipLlavaPreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -246,10 +241,18 @@ class VipLlavaModel(VipLlavaPreTrainedModel): pixel_values=pixel_values, vision_feature_layers=vision_feature_layers ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/vipllava/modular_vipllava.py b/src/transformers/models/vipllava/modular_vipllava.py index 0c52b013639..97458112ad5 100644 --- a/src/transformers/models/vipllava/modular_vipllava.py +++ b/src/transformers/models/vipllava/modular_vipllava.py @@ -136,11 +136,6 @@ class VipLlavaModel(LlavaModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if pixel_values is 
not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -149,10 +144,18 @@ class VipLlavaModel(LlavaModel): pixel_values=pixel_values, vision_feature_layers=vision_feature_layers ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = (special_image_mask).sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index c7d04dab28f..be43995e97d 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -55,6 +55,7 @@ from ...utils import ( is_torch_flex_attn_available, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_wav2vec2 import Wav2Vec2Config @@ -524,6 +525,7 @@ class Wav2Vec2Attention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", version="4.54.0") def forward( self, hidden_states: torch.Tensor, @@ -531,7 +533,7 @@ class Wav2Vec2Attention(nn.Module): past_key_value: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + output_attentions: Optional[bool] = False, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -552,42 +554,9 @@ class Wav2Vec2Attention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = 
self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + current_states = key_value_states if is_cross_attention else hidden_states + key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -609,7 +578,7 @@ class Wav2Vec2Attention(nn.Module): attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, None class Wav2Vec2FeedForward(nn.Module): diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 0e043f354ee..f33082d2612 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -30,6 +30,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( ModelOutput, auto_docstring, + can_return_tuple, logging, torch_int, ) @@ -576,6 +577,7 @@ class XCLIPEncoder(nn.Module): self.layers = nn.ModuleList([XCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False + @can_return_tuple def forward( self, inputs_embeds, @@ -642,8 +644,6 @@ class XCLIPEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 3b9cb2c5201..78349b8b906 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1642,7 +1642,6 @@ def set_model_tester_for_less_flaky_test(test_case): "AriaVisionText2TextModelTester", "GPTNeoModelTester", "DPTModelTester", - "Gemma3nTextModelTester", # cannot have a single layer combined with the cache sharing config attrs in the tester ] if test_case.model_tester.__class__.__name__ in exceptional_classes: target_num_hidden_layers = None diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 746fd2179d2..c9e7f9f4f13 100644 --- a/tests/generation/test_utils.py +++ 
b/tests/generation/test_utils.py @@ -118,27 +118,6 @@ from unittest.mock import patch from transformers.utils import is_sklearn_available -# TODO: raushan remove this when VLMs start accepting input embeds -VLM_CLASS_NAMES = [ - "llava", - "idefics2", - "idefics3", - "mllama", - "paligemma", - "emu3", - "gotocr2", - "qwen2vl", - "qwen2_5_vl", - "ayavision", - "janus", - "gemma3", - "mistral3", - "chameleon", - "internvl", - "qwen2_5omni", # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni` -] - - class GenerationTesterMixin: input_name = "input_ids" model_tester = None @@ -1228,7 +1207,23 @@ class GenerationTesterMixin: "blip2", # overridden `generate()` "instructblip", "instructblipvideo", - *VLM_CLASS_NAMES, # shouldn't suggest image tokens + # All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan + "llava", + "idefics2", + "idefics3", + "mllama", + "paligemma", + "emu3", + "gotocr2", + "qwen2vl", + "qwen2_5_vl", + "ayavision", + "janus", + "gemma3", + "mistral3", + "chameleon", + "internvl", + "qwen2_5omni", # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`, ] ): self.skipTest(reason="May fix in the future: need model-specific fixes") @@ -1641,6 +1636,58 @@ class GenerationTesterMixin: self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + @pytest.mark.generate + def test_generate_from_random_inputs_embeds(self): + """ + Text-only: Tests that different `inputs_embeds` generate different outputs in models with `main_input=="input_ids"`. + Some models have 'images' as main input and thus can't generate with random text embeddings. + See `test_generate_from_inputs_embeds` for more general checks. + """ + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + + if config.is_encoder_decoder: + continue + config.is_decoder = True + + model = model_class(config).to(torch_device).eval() + if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): + continue + + # No easy fix, let's skip the test for now + has_complex_embeds_computation = any( + model_name in model_class.__name__.lower() for model_name in ["moshi"] + ) + + if model_class.main_input_name != "input_ids" or has_complex_embeds_computation: + self.skipTest( + "The model's main input name in not `input_ids` and we need kwargs from input dict as well." 
+ ) + + if hasattr(config, "scale_embedding"): + config.scale_embedding = False + + generation_kwargs = { + "return_dict_in_generate": True, + "output_scores": True, + "do_sample": False, + "max_new_tokens": 5, + "min_new_tokens": 5, # generate exactly 5 tokens + } + + input_ids = inputs_dict.pop("input_ids") + inputs_embeds = model.get_input_embeddings()(input_ids) + outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds, **generation_kwargs) + + # If we pass different inputs_embeds, we should get different outputs (the output text may be the + # same, but the logits will almost surely be different) + random_embeds = torch.rand_like(inputs_embeds) + outputs_from_rand_embeds = model.generate( + input_ids=input_ids, inputs_embeds=random_embeds, **generation_kwargs + ) + for i in range(len(outputs_from_rand_embeds.scores)): + self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i])) + @pytest.mark.generate @parameterized.expand([("greedy", 1), ("beam search", 2)]) def test_generate_from_inputs_embeds(self, _, num_beams): @@ -1662,34 +1709,22 @@ class GenerationTesterMixin: continue # There are a few exception patterns in this test: - # 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed - requires_inputs_ids = any(model_name in model_class.__name__.lower() for model_name in ["idefics"]) - # 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex + # 1 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex # than calling the embedding layer with `input_ids`. Subcases of this exception: - # 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag) + # 1.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag) if hasattr(config, "scale_embedding"): config.scale_embedding = False - # 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the - # exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the - # checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images` - pixel_values_is_mutually_exclusive = any( - model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES - ) - if pixel_values_is_mutually_exclusive: - inputs_dict.pop("pixel_values", None) - inputs_dict.pop("pixel_values_videos", None) - inputs_dict.pop("pixel_values_images", None) # HACK - in the case of granite speech, input_features and inputs_embeds are mutually exclusive; # this is similar to VLMs and should likely be standardized for similar audio models in the future, # then made generic here. if "granitespeech" in model_class.__name__.lower(): inputs_dict.pop("input_features", None) - # 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds` + # 1.B - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds` has_complex_embeds_computation = any( model_name in model_class.__name__.lower() for model_name in ["moshi"] ) - # 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate, + # 2 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate, # we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input. 
missing_attention_mask = "attention_mask" not in inputs_dict @@ -1702,31 +1737,23 @@ class GenerationTesterMixin: "do_sample": False, "max_new_tokens": 5, "min_new_tokens": 5, # generate exactly 5 tokens + "use_cache": True, } - outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict) + outputs_from_ids = model.generate(input_ids=input_ids, **generation_kwargs, **inputs_dict) self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5)) # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output). # The output of the two calls should be the same. inputs_embeds = model.get_input_embeddings()(input_ids) outputs_from_embeds = model.generate( - input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict + input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict ) if not has_complex_embeds_computation: self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds) - # If we pass different inputs_embeds, we should get different outputs (the output text may be the - # same, but the logits will almost surely be different) - random_embeds = torch.rand_like(inputs_embeds) - outputs_from_rand_embeds = model.generate( - input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict - ) - for i in range(len(outputs_from_rand_embeds.scores)): - self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i])) - # input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will # be the same - if not (requires_inputs_ids or missing_attention_mask): + if not missing_attention_mask: outputs_from_embeds_wo_ids = model.generate( inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict ) @@ -1753,17 +1780,6 @@ class GenerationTesterMixin: if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): self.skipTest(reason="This model does not support `inputs_embeds` in generation") - # Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the - # exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the - # checks without adding test complexity. 
Ditto for `pixel_values_videos` and `pixel_values_images` - pixel_values_is_mutually_exclusive = any( - model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES - ) - if pixel_values_is_mutually_exclusive: - inputs_dict.pop("pixel_values", None) - inputs_dict.pop("pixel_values_videos", None) - inputs_dict.pop("pixel_values_images", None) - input_ids = inputs_dict.pop("input_ids") model.config.use_cache = True @@ -1925,14 +1941,6 @@ class GenerationTesterMixin: if "past_key_values" not in outputs: self.skipTest(reason="This model doesn't return `past_key_values`") - pixel_values_is_mutually_exclusive = any( - model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES - ) - if pixel_values_is_mutually_exclusive: - inputs_dict.pop("pixel_values", None) - inputs_dict.pop("pixel_values_videos", None) - inputs_dict.pop("pixel_values_images", None) - input_ids = inputs_dict.pop("input_ids") model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 diff --git a/tests/models/altclip/test_modeling_altclip.py b/tests/models/altclip/test_modeling_altclip.py index 0ad5e5cb03d..d56f6326acf 100755 --- a/tests/models/altclip/test_modeling_altclip.py +++ b/tests/models/altclip/test_modeling_altclip.py @@ -297,7 +297,7 @@ class AltCLIPTextModelTester: @require_torch class AltCLIPTextModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = (AltCLIPTextModel,) if is_torch_available() else () - fx_compatible = True + fx_compatible = False # Cannot support if `can_return_tuple` test_pruning = False test_head_masking = False @@ -411,7 +411,7 @@ def prepare_img(): class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (AltCLIPModel,) if is_torch_available() else () pipeline_model_mapping = {"feature-extraction": AltCLIPModel} if is_torch_available() else {} - fx_compatible = True + fx_compatible = False # Cannot support if `can_return_tuple` test_head_masking = False test_pruning = False test_resize_embeddings = False diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index 747963aa50e..ece423338a6 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -189,49 +189,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi self.model_tester = AriaVisionText2TextModelTester(self) self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = 
self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - @unittest.skip( reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) @@ -270,14 +227,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi def test_dola_decoding_sample(self): pass - @unittest.skip(reason="Unsupported") - def test_generate_from_inputs_embeds_0_greedy(self): - pass - - @unittest.skip(reason="Unsupported") - def test_generate_from_inputs_embeds_1_beam_search(self): - pass - @unittest.skip(reason="Dynamic control flow due to MoE") def test_generate_with_static_cache(self): pass diff --git a/tests/models/aya_vision/test_modeling_aya_vision.py b/tests/models/aya_vision/test_modeling_aya_vision.py index 5cde1f216ec..d472e0eb90f 100644 --- a/tests/models/aya_vision/test_modeling_aya_vision.py +++ b/tests/models/aya_vision/test_modeling_aya_vision.py @@ -62,7 +62,7 @@ class AyaVisionVisionText2TextModelTester: bos_token_id=0, eos_token_id=0, pad_token_id=0, - image_token_index=1, + image_token_index=2, num_channels=3, image_size=64, model_type="aya_vision", @@ -183,49 +183,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_config(self): self.config_tester.run_common_tests() - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - @unittest.skip("Failing because of unique cache (HybridCache)") def test_model_outputs_equivalence(self, **kwargs): pass @@ -285,10 +242,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation using input embeds.") - def test_generate_continue_from_inputs_embeds(self): - pass - @unittest.skip("Failing because of 
unique cache (HybridCache)") def test_multi_gpu_data_parallel_forward(self): pass diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index f11583d7293..af95bbb2c32 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -20,7 +20,6 @@ import unittest import numpy as np import pytest import requests -from parameterized import parameterized from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig from transformers.testing_utils import ( @@ -674,15 +673,6 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT # They should result in very similar logits torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) - @unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present") - @parameterized.expand([("greedy", 1), ("beam search", 2)]) - def test_generate_from_inputs_embeds(self, _, num_beams): - pass - - @unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present") - def test_generate_from_inputs_embeds_with_static_cache(self): - pass - # this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py class Blip2TextModelTester: diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index 842c3ecddb6..fb5847fd602 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -355,49 +355,6 @@ class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unit pixel_values = torch.cat([pixel_values, pixel_values], dim=0) _ = model(input_ids=input_ids, pixel_values=pixel_values) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - @require_torch class ChameleonIntegrationTest(unittest.TestCase): diff --git a/tests/models/colpali/test_modeling_colpali.py b/tests/models/colpali/test_modeling_colpali.py index 82844e3fd37..3cdfc062273 100644 --- a/tests/models/colpali/test_modeling_colpali.py +++ 
b/tests/models/colpali/test_modeling_colpali.py @@ -189,50 +189,6 @@ class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase): self.model_tester = ColPaliForRetrievalModelTester(self) self.config_tester = ConfigTester(self, config_class=ColPaliConfig, has_text_modality=False) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - @slow @require_vision def test_colpali_forward_inputs(self): diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py index 9c9435fbe43..6d9780509b5 100644 --- a/tests/models/emu3/test_modeling_emu3.py +++ b/tests/models/emu3/test_modeling_emu3.py @@ -331,49 +331,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline def test_config(self): self.config_tester.run_common_tests() - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - @unittest.skip( 
"Emu3 has a VQ module that uses `weight.data` directly in forward which prevent offloding on that module" ) diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 2c4643cf6c5..1f32df39749 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -131,10 +131,6 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") - def test_generate_continue_from_inputs_embeds(self): - pass - @unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.") def test_eager_matches_fa2_generate(self): pass diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py index 2f546e19e49..060bf15ea1e 100644 --- a/tests/models/gemma3n/test_modeling_gemma3n.py +++ b/tests/models/gemma3n/test_modeling_gemma3n.py @@ -39,13 +39,20 @@ from transformers.testing_utils import ( require_read_token, require_torch, require_torch_gpu, + require_torch_sdpa, slow, torch_device, ) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, + ModelTesterMixin, + _test_eager_matches_sdpa_inference, + floats_tensor, + ids_tensor, +) from ..gemma.test_modeling_gemma import GemmaModelTester @@ -256,6 +263,7 @@ class Gemma3nTextModelTester(GemmaModelTester): vocab_size=99, vocab_size_per_layer_input=99, hidden_size=16, + hidden_size_per_layer_input=16, num_hidden_layers=4, # override to correctly test sharing cache pattern num_kv_shared_layers=2, # important to override layer_types=[ @@ -291,6 +299,7 @@ class Gemma3nTextModelTester(GemmaModelTester): self.vocab_size = vocab_size self.vocab_size_per_layer_input = vocab_size_per_layer_input self.hidden_size = hidden_size + self.hidden_size_per_layer_input = hidden_size_per_layer_input self.num_hidden_layers = num_hidden_layers self.num_kv_shared_layers = num_kv_shared_layers self.layer_types = layer_types @@ -317,7 +326,6 @@ class Gemma3nTextModelTester(GemmaModelTester): for_causal_lm_class = Gemma3nForCausalLM -@unittest.skip("Skipped for now!") @require_torch class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else () @@ -365,6 +373,64 @@ class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes [expected_shape] * len(iter_hidden_states), ) + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + @require_torch_sdpa + def test_eager_matches_sdpa_inference( + self, + name, + torch_dtype, + padding_side, + use_attention_mask, + output_attentions, + enable_kernels, + ): + "We need to relax a bit the `atols` for fp32 here due to the altup projections" + atols = { + ("cpu", False, torch.float32): 1e-3, # this was relaxed + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-3, # this was relaxed + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-3, # this was relaxed + ("cuda", False, 
torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-3, # this was relaxed + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + _test_eager_matches_sdpa_inference( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels, atols=atols + ) + + @pytest.mark.generate + @unittest.skip( + "Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding" + ) + def test_contrastive_generate(self): + pass + + @pytest.mark.generate + @unittest.skip( + "Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding" + ) + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @pytest.mark.generate + @unittest.skip( + "Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding" + ) + def test_contrastive_generate_low_memory(self): + pass + + @pytest.mark.generate + @unittest.skip( + "Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with dola decoding" + ) + def test_dola_decoding_sample(self): + pass + class Gemma3nVision2TextModelTester: text_config = {"activation_sparsity_pattern": None} diff --git a/tests/models/glm4v/test_modeling_glm4v.py b/tests/models/glm4v/test_modeling_glm4v.py index 4e444fb56e8..48d4a9b858e 100644 --- a/tests/models/glm4v/test_modeling_glm4v.py +++ b/tests/models/glm4v/test_modeling_glm4v.py @@ -13,12 +13,10 @@ # limitations under the License. """Testing suite for the PyTorch GLM-4.1V model.""" -import copy import gc import unittest import requests -from parameterized import parameterized from transformers import ( AutoProcessor, @@ -237,11 +235,6 @@ class Glm4vModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) def test_sdpa_can_dispatch_on_flash(self): pass - @parameterized.expand([("greedy", 1), ("beam search", 2)]) - @unittest.skip("Cannot generate from inputs embeds with pixel values") - def test_generate_from_inputs_embeds(self): - pass - @unittest.skip(reason="Size mismatch") def test_multi_gpu_data_parallel_forward(self): pass @@ -250,34 +243,11 @@ class Glm4vModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase) def test_model_is_small(self): pass - @unittest.skip("Cannot generate from inputs embeds with pixel values") + @unittest.skip("Error with compilation") def test_generate_from_inputs_embeds_with_static_cache(self): pass - # The multimodal base model embeds will not match ids, due to pixel values. We can't change base test - # because in some models `pixel_values` are required. 
Will be fixed when we add support for merging `embeds+pixels` - # TODO: @raushan - - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - del inputs["image_grid_thw"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - with torch.no_grad(): - model(**inputs)[0] - + # RoPE index doesn't match when using embeddings def test_inputs_embeds_matches_input_ids(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py index 047d4a0da9b..87f182ac9cd 100644 --- a/tests/models/got_ocr2/test_modeling_got_ocr2.py +++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py @@ -51,9 +51,6 @@ class GotOcr2VisionText2TextModelTester: num_channels=3, ignore_index=-100, image_size=64, - bos_token_id=0, - eos_token_id=0, - pad_token_id=0, image_token_index=1, model_type="got_ocr2", is_training=True, @@ -71,6 +68,9 @@ class GotOcr2VisionText2TextModelTester: "rope_theta": 10000, "mlp_ratio": 4, "tie_word_embeddings": True, + "bos_token_id": 2, + "eos_token_id": 3, + "pad_token_id": 4, }, vision_config={ "num_hidden_layers": 2, @@ -85,9 +85,9 @@ class GotOcr2VisionText2TextModelTester: ): self.parent = parent self.ignore_index = ignore_index - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id + self.bos_token_id = text_config["bos_token_id"] + self.eos_token_id = text_config["eos_token_id"] + self.pad_token_id = text_config["pad_token_id"] self.image_token_index = image_token_index self.model_type = model_type self.text_config = text_config @@ -109,9 +109,6 @@ class GotOcr2VisionText2TextModelTester: text_config=self.text_config, vision_config=self.vision_config, model_type=self.model_type, - bos_token_id=self.bos_token_id, - eos_token_id=self.eos_token_id, - pad_token_id=self.pad_token_id, image_token_index=self.image_token_index, ) @@ -127,7 +124,6 @@ class GotOcr2VisionText2TextModelTester: input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) - # input_ids[:, -1] = self.pad_token_id input_ids[input_ids == self.image_token_index] = self.pad_token_id input_ids[:, : self.num_image_tokens] = self.image_token_index @@ -181,55 +177,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" 
for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - - @unittest.skip( - reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs" - ) - def test_generate_from_inputs_embeds_with_static_cache(self): - pass - @unittest.skip( reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format" ) diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 1fbf788e377..94ab9491dfb 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -315,13 +315,6 @@ class IdeficsModelTester: def prepare_pixel_values(self): return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) - @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) - @unittest.skip(reason="Idefics has a hard requirement on SDPA, skipping this test") - def test_eager_matches_sdpa_inference( - self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels - ): - pass - @require_torch class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixin, unittest.TestCase): @@ -611,6 +604,12 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMi def test_sdpa_can_dispatch_non_composite_models(self): pass + @unittest.skip(reason="Idefics can't do text-only inference") + def test_generate_from_random_inputs_embeds( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + pass + @require_torch class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase): @@ -899,6 +898,12 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni def test_generation_tester_mixin_inheritance(self): pass + @unittest.skip(reason="Idefics can't do text-only inference") + def test_generate_from_random_inputs_embeds( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + pass + @require_torch @require_vision diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index 6ce19ddfade..7bd0656fff2 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -108,6 +108,7 @@ class Idefics2VisionText2TextModelTester: image_token_id=99, ): self.parent = parent + self.pad_token_id = text_config["pad_token_id"] self.is_training = is_training self.batch_size = batch_size self.num_images = num_images @@ -158,6 +159,7 @@ class Idefics2VisionText2TextModelTester: # For simplicity just set the last n tokens to the image token n_image_tokens_per_batch = self.num_images * self.perceiver_config["resampler_n_latents"] + 
input_ids[input_ids == self.image_token_id] = self.pad_token_id input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id attention_mask = input_ids.ne(1).to(torch_device) inputs_dict = { diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py index 69a0f85acef..5cf06a50be1 100644 --- a/tests/models/idefics3/test_modeling_idefics3.py +++ b/tests/models/idefics3/test_modeling_idefics3.py @@ -96,6 +96,7 @@ class Idefics3VisionText2TextModelTester: image_token_id=57, ): self.parent = parent + self.pad_token_id = text_config["pad_token_id"] self.is_training = is_training self.batch_size = batch_size self.num_images = num_images @@ -148,6 +149,7 @@ class Idefics3VisionText2TextModelTester: # For simplicity just set the last n tokens to the image token n_image_tokens_per_batch = self.seq_length + input_ids[input_ids == self.image_token_id] = self.pad_token_id input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id attention_mask = input_ids.ne(1).to(torch_device) inputs_dict = { diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index 66621fc0fe5..f4419cba5be 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -20,7 +20,6 @@ import unittest import numpy as np import pytest import requests -from parameterized import parameterized from transformers import ( CONFIG_MAPPING, @@ -522,12 +521,6 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene def test_model_get_set_embeddings(self): pass - @unittest.skip( - "InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present" - ) - def test_generate_from_inputs_embeds_with_static_cache(self): - pass - def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -656,13 +649,6 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene # They should result in very similar logits torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) - @unittest.skip( - "InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present" - ) - @parameterized.expand([("greedy", 1), ("beam search", 2)]) - def test_generate_from_inputs_embeds(self, _, num_beams): - pass - @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): """ diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py index 17e6b0a64d7..d996ab778a5 100644 --- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py @@ -20,7 +20,6 @@ import unittest import numpy as np import pytest from huggingface_hub import hf_hub_download -from parameterized import parameterized from transformers import ( CONFIG_MAPPING, @@ -535,12 +534,6 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest( def test_model_common_attributes(self): pass - @unittest.skip( - "InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present" - ) - def test_generate_from_inputs_embeds_with_static_cache(self): - pass - def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -669,13 +662,6 @@ class 
InstructBlipVideoForConditionalGenerationDecoderOnlyTest( # They should result in very similar logits torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) - @unittest.skip( - "InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present" - ) - @parameterized.expand([("greedy", 1), ("beam search", 2)]) - def test_generate_from_inputs_embeds(self, _, num_beams): - pass - @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): """ diff --git a/tests/models/internvl/test_modeling_internvl.py b/tests/models/internvl/test_modeling_internvl.py index 19eb3cc4c71..7b5d6a29050 100644 --- a/tests/models/internvl/test_modeling_internvl.py +++ b/tests/models/internvl/test_modeling_internvl.py @@ -63,9 +63,6 @@ class InternVLVisionText2TextModelTester: image_seq_length=64, vision_feature_layer=-1, ignore_index=-100, - bos_token_id=0, - eos_token_id=0, - pad_token_id=0, image_token_id=1, num_channels=3, image_size=64, @@ -85,9 +82,9 @@ class InternVLVisionText2TextModelTester: "rope_theta": 10000, "mlp_ratio": 4, "tie_word_embeddings": True, - "bos_token_id": 0, - "eos_token_id": 0, - "pad_token_id": 0, + "bos_token_id": 3, + "eos_token_id": 4, + "pad_token_id": 5, }, vision_config={ "hidden_size": 32, @@ -103,9 +100,9 @@ class InternVLVisionText2TextModelTester: ): self.parent = parent self.ignore_index = ignore_index - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id + self.bos_token_id = text_config["bos_token_id"] + self.eos_token_id = text_config["eos_token_id"] + self.pad_token_id = text_config["pad_token_id"] self.image_token_id = image_token_id self.model_type = model_type self.text_config = text_config @@ -128,9 +125,6 @@ class InternVLVisionText2TextModelTester: text_config=self.text_config, vision_config=self.vision_config, model_type=self.model_type, - bos_token_id=self.bos_token_id, - eos_token_id=self.eos_token_id, - pad_token_id=self.pad_token_id, image_token_id=self.image_token_id, image_seq_length=self.image_seq_length, vision_feature_layer=self.vision_feature_layer, @@ -148,7 +142,6 @@ class InternVLVisionText2TextModelTester: input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) - # input_ids[:, -1] = self.pad_token_id input_ids[input_ids == self.image_token_id] = self.pad_token_id input_ids[:, : self.image_seq_length] = self.image_token_id @@ -222,49 +215,6 @@ class InternVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - @unittest.skip(reason="Compile not yet supported because in LLava models") def test_sdpa_can_compile_dynamic(self): pass diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 5c142e066f9..254da46ac88 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -153,6 +153,7 @@ class JanusVisionText2TextModelTester: text_config=self.text_config, vision_config=self.vision_config, vq_config=self.get_vq_config(), + image_token_id=self.image_token_index, ) def prepare_config_and_inputs(self): @@ -200,50 +201,6 @@ class JanusVisionText2TextModelTest(ModelTesterMixin, GenerationTesterMixin, uni self.model_tester = JanusVisionText2TextModelTester(self) self.config_tester = ConfigTester(self, config_class=JanusConfig, has_text_modality=False) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - del inputs["generation_mode"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # Overwrite inputs_embeds tests because we need to delete "pixel values" for VLMs. 
- def test_inputs_embeds_matches_input_ids(self):
- config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- model = model_class(config)
- model.to(torch_device)
- model.eval()
-
- inputs = self._prepare_for_class(inputs_dict, model_class)
- input_ids = inputs["input_ids"]
- del inputs["input_ids"]
- del inputs["pixel_values"]
- del inputs["generation_mode"]
-
- inputs_embeds = model.get_input_embeddings()(input_ids)
-
- with torch.no_grad():
- out_ids = model(input_ids=input_ids, **inputs)[0]
- out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
- torch.testing.assert_close(out_embeds, out_ids)
-
 def test_sdpa_can_dispatch_composite_models(self): for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py
index 1e61d536d75..e23b8672d26 100644
--- a/tests/models/kosmos2/test_modeling_kosmos2.py
+++ b/tests/models/kosmos2/test_modeling_kosmos2.py
@@ -457,14 +457,6 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape) # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
- @pytest.mark.generate
- @parameterized.expand([("greedy", 1), ("beam search", 2)])
- @unittest.skip(
- "KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking input args because KOSMOS-2 has `generate()` overwritten"
- )
- def test_generate_from_inputs_embeds(self):
- pass
-
 @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @require_torch_sdpa @unittest.skip("KOSMOS-2 doesn't support padding")
@@ -613,6 +605,53 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi # (Even with this call, there are still memory leak by ~0.04MB) self.clear_torch_jit_class_registry()
+ @pytest.mark.generate
+ @parameterized.expand([("greedy", 1), ("beam search", 2)])
+ def test_generate_from_inputs_embeds(self, _, num_beams):
+ """Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
+ # NOTE: overwritten because Kosmos with ids prepares position ids differently from embeds
+ # If the model gets ids, all pad tokens are masked from position ids.
That is not possible with embeds + + # When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids` + # if fails, you should probably update the `prepare_inputs_for_generation` function + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + config.is_decoder = True + + # Skip models without explicit support + model = model_class(config).to(torch_device).eval() + + # Traditional way of generating text + input_ids = inputs_dict.pop("input_ids") + input_ids[input_ids == config.get_text_config().pad_token_id] = 0 + generation_kwargs = { + "return_dict_in_generate": True, + "output_scores": True, + "num_beams": num_beams, + "do_sample": False, + "max_new_tokens": 5, + "min_new_tokens": 5, # generate exactly 5 tokens + "use_cache": True, + } + outputs_from_ids = model.generate(input_ids=input_ids, **generation_kwargs, **inputs_dict) + self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5)) + + # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output). + # The output of the two calls should be the same. + inputs_embeds = model.get_input_embeddings()(input_ids) + outputs_from_embeds = model.generate( + input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict + ) + self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds) + + # input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will + # be the same + outputs_from_embeds_wo_ids = model.generate( + inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict + ) + outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :] + self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds) + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/layoutlm/test_modeling_layoutlm.py b/tests/models/layoutlm/test_modeling_layoutlm.py index b6c820b9acf..a7cd8701560 100644 --- a/tests/models/layoutlm/test_modeling_layoutlm.py +++ b/tests/models/layoutlm/test_modeling_layoutlm.py @@ -243,7 +243,7 @@ class LayoutLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase if is_torch_available() else {} ) - fx_compatible = True + fx_compatible = False # Cannot support if `can_return_tuple` def setUp(self): self.model_tester = LayoutLMModelTester(self) diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index edfbbe9f0e1..6900ce27977 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -196,49 +196,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM def test_config(self): self.config_tester.run_common_tests() - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests 
because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - def test_mismatching_num_image_tokens(self): """ Tests that VLMs through an error with explicit message saying what is wrong diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index cb573913e4a..8c91176225f 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -222,49 +222,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - def test_mismatching_num_image_tokens(self): """ Tests that VLMs through an error with explicit message saying what is wrong diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index 3f52d8291d7..352a3ef1915 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -86,7 +86,7 @@ class LlavaNextVideoVisionText2TextModelTester: "initializer_range": 0.02, "num_labels": 3, "num_choices": 4, - "pad_token_id": 2, + "pad_token_id": 3, }, is_training=True, vision_config={ @@ -234,51 +234,6 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati msg=f"Parameter {name} of model {model_class} seems 
not properly initialized", ) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - del inputs["pixel_values_videos"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - del inputs["pixel_values_videos"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - def test_mismatching_num_image_tokens(self): """ Tests that VLMs through an error with explicit message saying what is wrong diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py index f482f0a0680..2b134bf5a0e 100644 --- a/tests/models/llava_onevision/test_modeling_llava_onevision.py +++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py @@ -230,49 +230,6 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - def test_odd_sized_image(self): # 
prepare model configuration config = self.model_tester.get_config() diff --git a/tests/models/mistral3/test_modeling_mistral3.py b/tests/models/mistral3/test_modeling_mistral3.py index 595044a6fd3..9f0e4ef6c53 100644 --- a/tests/models/mistral3/test_modeling_mistral3.py +++ b/tests/models/mistral3/test_modeling_mistral3.py @@ -57,9 +57,6 @@ class Mistral3VisionText2TextModelTester: image_seq_length=4, vision_feature_layer=-1, ignore_index=-100, - bos_token_id=0, - eos_token_id=0, - pad_token_id=0, image_token_index=1, num_channels=3, image_size=30, @@ -80,9 +77,9 @@ class Mistral3VisionText2TextModelTester: "rms_norm_eps": 1e-05, "rope_theta": 1000000000.0, "sliding_window": None, - "bos_token_id": 0, - "eos_token_id": 0, - "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 3, + "pad_token_id": 4, }, vision_config={ "model_type": "pixtral", @@ -98,9 +95,9 @@ class Mistral3VisionText2TextModelTester: ): self.parent = parent self.ignore_index = ignore_index - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.pad_token_id = pad_token_id + self.bos_token_id = text_config["bos_token_id"] + self.eos_token_id = text_config["eos_token_id"] + self.pad_token_id = text_config["pad_token_id"] self.image_token_index = image_token_index self.model_type = model_type self.text_config = text_config @@ -209,49 +206,6 @@ class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - @unittest.skip(reason="Compile not yet supported because in LLava models") def test_sdpa_can_compile_dynamic(self): pass diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index 1be522f3a50..f75270283f1 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -199,49 +199,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes self.model_tester = PaliGemmaVisionText2TextModelTester(self) self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, 
has_text_modality=False) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens def test_mismatching_num_image_tokens(self): """ @@ -327,12 +284,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes def test_feed_forward_chunking(self): pass - @unittest.skip( - reason="VLMs doesn't accept inputs embeds and pixel values at the same time. So if the test passed for backbone LM, it passes for VLM also" - ) - def test_generate_from_inputs_embeds_with_static_cache(self): - pass - @unittest.skip( "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. 
Can be tested as part of LLM test" ) diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py index d40a6ec17e0..ba8d7a6cac6 100644 --- a/tests/models/paligemma2/test_modeling_paligemma2.py +++ b/tests/models/paligemma2/test_modeling_paligemma2.py @@ -183,49 +183,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe self.model_tester = PaliGemma2VisionText2TextModelTester(self) self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens def test_mismatching_num_image_tokens(self): """ @@ -311,12 +268,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe def test_feed_forward_chunking(self): pass - @unittest.skip( - reason="VLMs doesn't accept inputs embeds and pixel values at the same time. So if the test passed for backbone LM, it passes for VLM also" - ) - def test_generate_from_inputs_embeds_with_static_cache(self): - pass - @unittest.skip( "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. 
Can be tested as part of LLM test" ) diff --git a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py index e075c412ef7..a8505605c48 100644 --- a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py +++ b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py @@ -22,7 +22,6 @@ from urllib.request import urlopen import librosa import requests -from parameterized import parameterized from transformers import ( AutoProcessor, @@ -289,10 +288,6 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene def test_sdpa_can_dispatch_on_flash(self): pass - @unittest.skip(reason="QwenOmniThinker does not use inputs_embeds") - def test_inputs_embeds(self): - pass - @unittest.skip(reason="QwenOmniThinker does not support output_hidden_states test") def test_model_outputs_equivalence(self): pass @@ -337,11 +332,6 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - @parameterized.expand([("greedy", 1), ("beam search", 2)]) - @unittest.skip("Cannot generate from inputs embeds") - def test_generate_from_inputs_embeds(self): - pass - @unittest.skip("Cannot do contrastive generation, has custom `generate()`") def test_contrastive_generate(self): pass diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index 7894dc69806..019d3793333 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -184,7 +184,7 @@ class Qwen2_5_VLVisionText2TextModelTester: input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id inputs_dict = { "pixel_values": pixel_values, - "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size), + "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device), "input_ids": input_ids, "attention_mask": attention_mask, } @@ -357,39 +357,10 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test def test_model_is_small(self): pass - @unittest.skip( - reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs" - ) - def test_generate_from_inputs_embeds_with_static_cache(self): - pass - @is_flaky() # TODO (joao/raushan): Investigate why this test is flaky on this model def test_prompt_lookup_decoding_matches_greedy_search(self): super().test_prompt_lookup_decoding_matches_greedy_search() - # The multimodal base model embeds will not match ids, due to pixel values. We can't change base test - # because in some models `pixel_values` are required. 
Will be fixed when we add support for merging `embeds+pixels` - # TODO: @raushan - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - @require_torch class Qwen2_5_VLIntegrationTest(unittest.TestCase): diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index 72669fd390f..451f940ee00 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -176,7 +176,7 @@ class Qwen2VLVisionText2TextModelTester: inputs_dict = { "pixel_values": pixel_values, - "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size), + "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device), "input_ids": input_ids, "attention_mask": attention_mask, } @@ -313,35 +313,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas def test_model_is_small(self): pass - @unittest.skip( - reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs" - ) - def test_generate_from_inputs_embeds_with_static_cache(self): - pass - - # The multimodal base model embeds will not match ids, due to pixel values. We can't change base test - # because in some models `pixel_values` are required. 
Will be fixed when we add support for merging `embeds+pixels` - # TODO: @raushan - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - @require_torch class Qwen2VLIntegrationTest(unittest.TestCase): diff --git a/tests/models/smolvlm/test_modeling_smolvlm.py b/tests/models/smolvlm/test_modeling_smolvlm.py index cdeb0d95ec1..280399eb6b8 100644 --- a/tests/models/smolvlm/test_modeling_smolvlm.py +++ b/tests/models/smolvlm/test_modeling_smolvlm.py @@ -181,14 +181,6 @@ class SmolVLMModelTest(ModelTesterMixin, unittest.TestCase): def test_config(self): self.config_tester.run_common_tests() - @unittest.skip(reason="input_embeds cannot be passed in without input_ids") - def test_inputs_embeds(): - pass - - @unittest.skip(reason="input_embeds cannot be passed in without input_ids") - def test_inputs_embeds_matches_input_ids(self): - pass - @unittest.skip(reason="Model does not support padding right") def test_flash_attn_2_inference_padding_right(self): pass @@ -347,10 +339,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste self.model_tester = SmolVLMVisionText2TextModelTester(self) self.config_tester = ConfigTester(self, config_class=SmolVLMConfig, has_text_modality=False) - @unittest.skip(reason="input_embeds cannot be passed in without input_ids") - def test_inputs_embeds(): - pass - @unittest.skip(reason="Model does not support padding right") def test_flash_attn_2_inference_padding_right(self): pass @@ -394,14 +382,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Unsupported") - def test_generate_from_inputs_embeds_0_greedy(self): - pass - - @unittest.skip(reason="Unsupported") - def test_generate_from_inputs_embeds_1_beam_search(self): - pass - @unittest.skip(reason="Unsupported") def test_generate_with_static_cache(self): pass diff --git a/tests/models/splinter/test_modeling_splinter.py b/tests/models/splinter/test_modeling_splinter.py index 795adb86b30..f8a8121c40d 100644 --- a/tests/models/splinter/test_modeling_splinter.py +++ b/tests/models/splinter/test_modeling_splinter.py @@ -372,6 +372,18 @@ class SplinterModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase with torch.no_grad(): _ = model(**self._prepare_for_class(inputs_dict, model_class)) + @unittest.skip( + "Splinter GC with `use_reentrant` fails after #38751, FIXME raushan after deprecated args are removed" + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + "Splinter GC with `use_reentrant` fails after #38751, FIXME raushan after deprecated args are removed" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + @require_torch class SplinterModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index 
39196f2b1c2..4c9e4ff3ceb 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -344,51 +344,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe continue recursive_check(model_batched_output[key], model_row_output[key], model_name, key) - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values_images"] - del inputs["pixel_values_videos"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values_images"] - del inputs["pixel_values_videos"] - - inputs_embeds = model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - def test_mismatching_num_image_tokens(self): """ Tests that VLMs through an error with explicit message saying what is wrong diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index 07d9ab3c53e..65580977556 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -192,49 +192,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest def test_config(self): self.config_tester.run_common_tests() - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - def test_inputs_embeds(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - wte = model.get_input_embeddings() - inputs["inputs_embeds"] = wte(input_ids) - - with torch.no_grad(): - model(**inputs) - - # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs - # while some other models require pixel_values to be present - def test_inputs_embeds_matches_input_ids(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - input_ids = inputs["input_ids"] - del inputs["input_ids"] - del inputs["pixel_values"] - - inputs_embeds = 
model.get_input_embeddings()(input_ids) - - with torch.no_grad(): - out_ids = model(input_ids=input_ids, **inputs)[0] - out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] - torch.testing.assert_close(out_embeds, out_ids) - # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens def test_mismatching_num_image_tokens(self): """ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d7a41a6c5d0..0587c73bd9b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -156,6 +156,334 @@ TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION = [ ] + [("fp32_pad_left_output_attentions", "fp32", "left", True, True, False)] +def _test_eager_matches_sdpa_inference( + self, + name, + torch_dtype, + padding_side, + use_attention_mask, + output_attentions, + enable_kernels, + atols=None, + rtols=None, +): + """ + This test is written as a regular function so that it can easily be overloaded with different tolerances. + Otherwise, `parameterized.expand` prevents this, as it removes the original function from the namespace. + """ + # TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like + # models have a custom mixin, which we detect to skip this test. + if any(".CLIPModelTesterMixin" in str(base) for base in self.__class__.__bases__): + self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`") + + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + # convert shorthand name to torch.dtype + if torch_dtype == "fp16": + torch_dtype = torch.float16 + elif torch_dtype == "bf16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "fp32": + torch_dtype = torch.float32 + + if not is_torch_fp16_available_on_device(torch_device) and torch_dtype == torch.float16: + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if not is_torch_bf16_available_on_device(torch_device) and torch_dtype == torch.bfloat16: + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Dictionary of tolerances for eager <> sdpa tests. 
Key = (device, sdpa_kernels_enabled, dtype) + if atols is None: + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + if rtols is None: + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, # (different from others) + ("cuda", True, torch.float16): 5e-3, + } + + set_model_tester_for_less_flaky_test(self) + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + set_config_for_less_flaky_test(config) + model = model_class(config) + # TODO: standardize the interfaces for musicgen models, see other todo in this test + if model.__class__.__name__ == "MusicgenMelodyForConditionalGeneration": + is_encoder_decoder = True + else: + is_encoder_decoder = model.config.is_encoder_decoder + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_from_pretrained_kwargs = { + "pretrained_model_name_or_path": tmpdirname, + "torch_dtype": torch_dtype, + } + + if hasattr(config, "use_mask_token") or "use_mask_token" in inspect.signature(model.__init__).parameters: + model_from_pretrained_kwargs["use_mask_token"] = True + + # TODO: remove this try/except, models should have a shared API + try: + model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="sdpa") + except ValueError: + model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs) + model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) + + model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype) + + set_model_for_less_flaky_test(model_eager) + set_model_for_less_flaky_test(model_sdpa) + + can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters + if not (self.has_attentions and can_output_attn) and output_attentions: + self.skipTest(reason="Model does not support output_attentions") + + # TODO: if we can also check with `batch_size=1` without being flaky? 
+ for batch_size in [7]: + # musicgen decoder models; TODO: find better abstraction + if ( + model.__class__.__name__.startswith("Musicgen") + and hasattr(self.model_tester, "num_codebooks") + and not hasattr(model_eager, "text_encoder") + ): + input_data_batch_size = batch_size * self.model_tester.num_codebooks + else: + input_data_batch_size = batch_size + + processed_inputs = {} + processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name] + + for key in getattr(self, "additional_model_inputs", []): + # Some models don't have all `additional_model_inputs`, especially when we + # craft cases to test model in different settings + if key in inputs_dict: + processed_inputs[key] = inputs_dict[key] + + for key, value in processed_inputs.items(): + if torch.is_floating_point(value): + value = value.to(torch_dtype) + + # extend value to have at least `input_data_batch_size` elements + if value.shape[0] < input_data_batch_size: + size = (input_data_batch_size - value.shape[0], *value.shape[1:]) + if torch.is_floating_point(value): + extension = torch.rand(size=size, dtype=value.dtype, device=torch_device) + else: + extension = torch.randint(high=5, size=size, dtype=value.dtype, device=torch_device) + value = torch.cat((value, extension), dim=0).to(torch_device) + + processed_inputs[key] = value[:input_data_batch_size] + + if not use_attention_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + if is_encoder_decoder: + seqlen = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name]).shape[ + -1 + ] + else: + seqlen = processed_inputs[model.main_input_name].shape[-1] + dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + + # extend dummy_attention_mask to have at least `batch_size` elements + if dummy_attention_mask.shape[0] < batch_size: + size = (batch_size - dummy_attention_mask.shape[0], *dummy_attention_mask.shape[1:]) + extension = torch.ones(size=size, dtype=dummy_attention_mask.dtype, device=torch_device) + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + + dummy_attention_mask = dummy_attention_mask[:batch_size].to(torch_device) + + dummy_attention_mask[:] = 1 + if padding_side == "left": + dummy_attention_mask[-1, :2] = 0 + dummy_attention_mask[-1, 2:] = 1 + elif padding_side == "right": + dummy_attention_mask[-1, -2:] = 0 + dummy_attention_mask[-1, :-2] = 1 + + if is_encoder_decoder: + # musicgen encoder-decoder models; TODO: find better abstraction + if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"): + input_data_batch_size = batch_size * self.model_tester.num_codebooks + else: + input_data_batch_size = batch_size + + decoder_input_ids = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name]) + decoder_input_ids = decoder_input_ids[:input_data_batch_size] + if decoder_input_ids.shape[0] != input_data_batch_size: + extension = torch.ones( + input_data_batch_size - decoder_input_ids.shape[0], + *decoder_input_ids.shape[1:], + dtype=decoder_input_ids.dtype, + device=torch_device, + ) + decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) + decoder_input_ids = decoder_input_ids.to(torch_device) + + # TODO: never an `attention_mask` arg here? 
+ processed_inputs.update( + { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + ) + else: + processed_inputs.update( + { + "output_hidden_states": True, + } + ) + + # Otherwise fails for e.g. WhisperEncoderModel + if "attention_mask" in inspect.signature(model_eager.forward).parameters: + processed_inputs["attention_mask"] = dummy_attention_mask + + if self.has_attentions and "output_attentions" in inspect.signature(model_sdpa.forward).parameters: + processed_inputs["output_attentions"] = output_attentions + if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters: + dummy_mask = torch.ones((self.model_tester.num_masks,)) + + # In case of additional token (like class) we define a custom `mask_length` + if hasattr(self.model_tester, "mask_length"): + mask_length = self.model_tester.mask_length - dummy_mask.size(0) + else: + mask_length = self.model_tester.seq_length - dummy_mask.size(0) + dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) + dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() + processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) + + if "noise" in inspect.signature(model_eager.forward).parameters: + np.random.seed(2) + num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2) + noise = np.random.uniform(size=(batch_size, num_patches)) + processed_inputs["noise"] = torch.from_numpy(noise) + + # TODO: test gradients as well (& for FA2 as well!) + with torch.no_grad(): + with sdpa_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + prepared_inputs = { + k: v.to(torch_device) if isinstance(v, torch.Tensor) else v for k, v in prepared_inputs.items() + } + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) + + if "logits_per_text" in outputs_eager: + key = "logits_per_text" + elif "vision_hidden_states" in outputs_eager: + key = "vision_hidden_states" + elif "audio_values" in outputs_eager: + key = "audio_values" + elif "decoder_hidden_states" in outputs_eager: + key = "decoder_hidden_states" + elif "logits" in outputs_eager and "Classification" in model_class.__name__: + key = "logits" + elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower(): + outputs_eager = outputs_eager["language_model_outputs"] + outputs_sdpa = outputs_sdpa["language_model_outputs"] + key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states" + else: + key = "hidden_states" + + # TODO: rename logits -> hidden_states + logits_eager = outputs_eager[key] + logits_sdpa = outputs_sdpa[key] + + if key in ["vision_hidden_states", "decoder_hidden_states", "hidden_states"]: + logits_eager = logits_eager[-1] + logits_sdpa = logits_sdpa[-1] + + if key == "logits_per_text": + nan_mask = torch.isnan(logits_eager) + logits_eager[nan_mask] = 0 + logits_sdpa[nan_mask] = 0 + + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + elif torch_device == "hpu": + atol = atols["cuda", enable_kernels, torch_dtype] + rtol = rtols["cuda", enable_kernels, torch_dtype] + elif torch_device == "xpu": + # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH + # which is implemented on PyTorch level using aten 
operators and is + # device agnostic with respect to implementation of each aten operator. + atol = atols["cuda", False, torch_dtype] + rtol = rtols["cuda", False, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. + if use_attention_mask: + _logits_sdpa = torch.zeros_like(input=logits_sdpa) + _logits_eager = torch.zeros_like(input=logits_eager) + + _logits_sdpa[:-1] = logits_sdpa[:-1] + _logits_eager[:-1] = logits_eager[:-1] + + if padding_side == "left": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:] + _logits_eager[-1:, 2:] = logits_eager[-1:, 2:] + + elif padding_side == "right": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2] + _logits_eager[-1:, 2:] = logits_eager[-1:, :-2] + + logits_sdpa = _logits_sdpa + logits_eager = _logits_eager + + results = [ + torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol) + for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) + ] + # If 80% batch elements have matched results, it's fine + if np.mean(results) < 0.8: + mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean() + raise ValueError( + f"mean relative difference for {key}: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = " + f"{rtol}" + ) + + def _config_zero_init(config): configs_no_init = copy.deepcopy(config) for key in configs_no_init.__dict__.keys(): @@ -2501,7 +2829,9 @@ class ModelTesterMixin: self.skipTest(reason="This model doesn't use `inputs_embeds`") inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) - pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1 + pad_token_id = ( + config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 1 + ) wte = model.get_input_embeddings() if not self.is_encoder_decoder: @@ -3405,321 +3735,9 @@ class ModelTesterMixin: def test_eager_matches_sdpa_inference( self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels ): - # TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like - # models have a custom mixin, which we detect to skip this test. - if any(".CLIPModelTesterMixin" in str(base) for base in self.__class__.__bases__): - self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`") - - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - if not self.all_model_classes[0]._supports_sdpa: - self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") - - # convert shorthand name to torch.dtype - if torch_dtype == "fp16": - torch_dtype = torch.float16 - elif torch_dtype == "bf16": - torch_dtype = torch.bfloat16 - elif torch_dtype == "fp32": - torch_dtype = torch.float32 - - if not is_torch_fp16_available_on_device(torch_device) and torch_dtype == torch.float16: - self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") - - if not is_torch_bf16_available_on_device(torch_device) and torch_dtype == torch.bfloat16: - self.skipTest( - f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" - ) - - # Dictionary of tolerances for eager <> sdpa tests. 
Key = (device, sdpa_kernels_enabled, dtype) - atols = { - ("cpu", False, torch.float32): 1e-6, - ("cpu", False, torch.float16): 5e-3, - ("cpu", False, torch.bfloat16): 1e-2, - ("cpu", True, torch.float32): 1e-6, - ("cpu", True, torch.float16): 5e-3, - ("cpu", True, torch.bfloat16): 1e-2, - ("cuda", False, torch.float32): 1e-6, - ("cuda", False, torch.bfloat16): 1e-2, - ("cuda", False, torch.float16): 5e-3, - ("cuda", True, torch.float32): 1e-6, - ("cuda", True, torch.bfloat16): 1e-2, - ("cuda", True, torch.float16): 5e-3, - } - rtols = { - ("cpu", False, torch.float32): 1e-4, - ("cpu", False, torch.float16): 5e-3, - ("cpu", False, torch.bfloat16): 1e-2, - ("cpu", True, torch.float32): 1e-4, - ("cpu", True, torch.float16): 5e-3, - ("cpu", True, torch.bfloat16): 1e-2, - ("cuda", False, torch.float32): 1e-4, - ("cuda", False, torch.bfloat16): 1e-2, - ("cuda", False, torch.float16): 5e-3, - ("cuda", True, torch.float32): 1e-4, - ("cuda", True, torch.bfloat16): 3e-2, # (different from others) - ("cuda", True, torch.float16): 5e-3, - } - - set_model_tester_for_less_flaky_test(self) - - for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - set_config_for_less_flaky_test(config) - model = model_class(config) - # TODO: standardize the interfaces for musicgen models, see other todo in this test - if model.__class__.__name__ == "MusicgenMelodyForConditionalGeneration": - is_encoder_decoder = True - else: - is_encoder_decoder = model.config.is_encoder_decoder - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model_from_pretrained_kwargs = { - "pretrained_model_name_or_path": tmpdirname, - "torch_dtype": torch_dtype, - } - - if ( - hasattr(config, "use_mask_token") - or "use_mask_token" in inspect.signature(model.__init__).parameters - ): - model_from_pretrained_kwargs["use_mask_token"] = True - - # TODO: remove this try/except, models should have a shared API - try: - model_sdpa = model_class.from_pretrained( - **model_from_pretrained_kwargs, attn_implementation="sdpa" - ) - except ValueError: - model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs) - model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) - - model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager") - model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype) - - set_model_for_less_flaky_test(model_eager) - set_model_for_less_flaky_test(model_sdpa) - - can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters - if not (self.has_attentions and can_output_attn) and output_attentions: - self.skipTest(reason="Model does not support output_attentions") - - # TODO: if we can also check with `batch_size=1` without being flaky? 
- for batch_size in [7]: - # musicgen decoder models; TODO: find better abstraction - if ( - model.__class__.__name__.startswith("Musicgen") - and hasattr(self.model_tester, "num_codebooks") - and not hasattr(model_eager, "text_encoder") - ): - input_data_batch_size = batch_size * self.model_tester.num_codebooks - else: - input_data_batch_size = batch_size - - processed_inputs = {} - processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name] - - for key in getattr(self, "additional_model_inputs", []): - # Some models don't have all `additional_model_inputs`, especially when we - # craft cases to test model in different settings - if key in inputs_dict: - processed_inputs[key] = inputs_dict[key] - - for key, value in processed_inputs.items(): - if torch.is_floating_point(value): - value = value.to(torch_dtype) - - # extend value to have at least `input_data_batch_size` elements - if value.shape[0] < input_data_batch_size: - size = (input_data_batch_size - value.shape[0], *value.shape[1:]) - if torch.is_floating_point(value): - extension = torch.rand(size=size, dtype=value.dtype, device=torch_device) - else: - extension = torch.randint(high=5, size=size, dtype=value.dtype, device=torch_device) - value = torch.cat((value, extension), dim=0).to(torch_device) - - processed_inputs[key] = value[:input_data_batch_size] - - if not use_attention_mask: - dummy_attention_mask = None - else: - dummy_attention_mask = inputs_dict.get("attention_mask", None) - if dummy_attention_mask is None: - if is_encoder_decoder: - seqlen = inputs_dict.get( - "decoder_input_ids", processed_inputs[model.main_input_name] - ).shape[-1] - else: - seqlen = processed_inputs[model.main_input_name].shape[-1] - dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) - - # extend dummy_attention_mask to have at least `batch_size` elements - if dummy_attention_mask.shape[0] < batch_size: - size = (batch_size - dummy_attention_mask.shape[0], *dummy_attention_mask.shape[1:]) - extension = torch.ones(size=size, dtype=dummy_attention_mask.dtype, device=torch_device) - dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) - - dummy_attention_mask = dummy_attention_mask[:batch_size].to(torch_device) - - dummy_attention_mask[:] = 1 - if padding_side == "left": - dummy_attention_mask[-1, :2] = 0 - dummy_attention_mask[-1, 2:] = 1 - elif padding_side == "right": - dummy_attention_mask[-1, -2:] = 0 - dummy_attention_mask[-1, :-2] = 1 - - if is_encoder_decoder: - # musicgen encoder-decoder models; TODO: find better abstraction - if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"): - input_data_batch_size = batch_size * self.model_tester.num_codebooks - else: - input_data_batch_size = batch_size - - decoder_input_ids = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name]) - decoder_input_ids = decoder_input_ids[:input_data_batch_size] - if decoder_input_ids.shape[0] != input_data_batch_size: - extension = torch.ones( - input_data_batch_size - decoder_input_ids.shape[0], - *decoder_input_ids.shape[1:], - dtype=decoder_input_ids.dtype, - device=torch_device, - ) - decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) - decoder_input_ids = decoder_input_ids.to(torch_device) - - # TODO: never an `attention_mask` arg here? 
- processed_inputs.update( - { - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": dummy_attention_mask, - "output_hidden_states": True, - } - ) - else: - processed_inputs.update( - { - "output_hidden_states": True, - } - ) - - # Otherwise fails for e.g. WhisperEncoderModel - if "attention_mask" in inspect.signature(model_eager.forward).parameters: - processed_inputs["attention_mask"] = dummy_attention_mask - - if self.has_attentions and "output_attentions" in inspect.signature(model_sdpa.forward).parameters: - processed_inputs["output_attentions"] = output_attentions - if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters: - dummy_mask = torch.ones((self.model_tester.num_masks,)) - - # In case of additional token (like class) we define a custom `mask_length` - if hasattr(self.model_tester, "mask_length"): - mask_length = self.model_tester.mask_length - dummy_mask.size(0) - else: - mask_length = self.model_tester.seq_length - dummy_mask.size(0) - dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) - dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() - processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) - - if "noise" in inspect.signature(model_eager.forward).parameters: - np.random.seed(2) - num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2) - noise = np.random.uniform(size=(batch_size, num_patches)) - processed_inputs["noise"] = torch.from_numpy(noise) - - # TODO: test gradients as well (& for FA2 as well!) - with torch.no_grad(): - with sdpa_kernel( - enable_flash=enable_kernels, - enable_math=True, - enable_mem_efficient=enable_kernels, - ): - prepared_inputs = self._prepare_for_class(processed_inputs, model_class) - prepared_inputs = { - k: v.to(torch_device) if isinstance(v, torch.Tensor) else v - for k, v in prepared_inputs.items() - } - outputs_eager = model_eager(**prepared_inputs) - outputs_sdpa = model_sdpa(**prepared_inputs) - - if "logits_per_text" in outputs_eager: - key = "logits_per_text" - elif "vision_hidden_states" in outputs_eager: - key = "vision_hidden_states" - elif "audio_values" in outputs_eager: - key = "audio_values" - elif "decoder_hidden_states" in outputs_eager: - key = "decoder_hidden_states" - elif "logits" in outputs_eager and "Classification" in model_class.__name__: - key = "logits" - elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower(): - outputs_eager = outputs_eager["language_model_outputs"] - outputs_sdpa = outputs_sdpa["language_model_outputs"] - key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states" - else: - key = "hidden_states" - - # TODO: rename logits -> hidden_states - logits_eager = outputs_eager[key] - logits_sdpa = outputs_sdpa[key] - - if key in ["vision_hidden_states", "decoder_hidden_states", "hidden_states"]: - logits_eager = logits_eager[-1] - logits_sdpa = logits_sdpa[-1] - - if key == "logits_per_text": - nan_mask = torch.isnan(logits_eager) - logits_eager[nan_mask] = 0 - logits_sdpa[nan_mask] = 0 - - if torch_device in ["cpu", "cuda"]: - atol = atols[torch_device, enable_kernels, torch_dtype] - rtol = rtols[torch_device, enable_kernels, torch_dtype] - elif torch_device == "hpu": - atol = atols["cuda", enable_kernels, torch_dtype] - rtol = rtols["cuda", enable_kernels, torch_dtype] - elif torch_device == "xpu": - # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH - # which is implemented on PyTorch level using aten 
operators and is - # device agnostic with respect to implementation of each aten operator. - atol = atols["cuda", False, torch_dtype] - rtol = rtols["cuda", False, torch_dtype] - else: - atol = 1e-7 - rtol = 1e-4 - - # Masked tokens output slightly deviates - we don't mind that. - if use_attention_mask: - _logits_sdpa = torch.zeros_like(input=logits_sdpa) - _logits_eager = torch.zeros_like(input=logits_eager) - - _logits_sdpa[:-1] = logits_sdpa[:-1] - _logits_eager[:-1] = logits_eager[:-1] - - if padding_side == "left": - _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:] - _logits_eager[-1:, 2:] = logits_eager[-1:, 2:] - - elif padding_side == "right": - _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2] - _logits_eager[-1:, 2:] = logits_eager[-1:, :-2] - - logits_sdpa = _logits_sdpa - logits_eager = _logits_eager - - results = [ - torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol) - for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) - ] - # If 80% batch elements have matched results, it's fine - if np.mean(results) < 0.8: - mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean() - raise ValueError( - f"mean relative difference for {key}: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = " - f"{rtol}" - ) + _test_eager_matches_sdpa_inference( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ) @require_torch_sdpa @require_torch_accelerator diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 04fb04a6473..8058558b407 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -276,6 +276,9 @@ SPECIAL_CASES_TO_ALLOW = { "attention_chunk_size", ], "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"], + # position_embedding_type not used and deprecated. Should be deleted in v4.55 + "LayoutLMConfig": ["position_embedding_type"], + "MarkupLMConfig": ["position_embedding_type"], "SmolLM3Config": ["no_rope_layer_interval"], "Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"], # this is for use in `timm` }