🚨 Don't use cache in non-generative models (#38751)

* deprecate for 1 version

* style

* fix some tests

* fix esm

* skip for now, GC requires positional args but we have keyword args

* remove transpose for scores in modified models only

* skip fx trace tests
This commit is contained in:
Raushan Turganbay 2025-07-01 11:08:21 +02:00 committed by GitHub
parent dbc98328da
commit e435574721
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 969 additions and 2328 deletions

View File

@ -16,7 +16,7 @@
import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.utils.checkpoint
@ -25,14 +25,15 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithNoAttention,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig
@ -90,7 +91,7 @@ class AlignOutput(ModelOutput):
The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`].
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
The output of [`AlignVisionModel`].
text_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`):
text_model_output (`BaseModelOutputWithPooling`):
The output of the [`AlignTextModel`].
vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`):
The output of the [`AlignVisionModel`].
@ -101,7 +102,7 @@ class AlignOutput(ModelOutput):
logits_per_text: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
image_embeds: Optional[torch.FloatTensor] = None
text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
text_model_output: BaseModelOutputWithPooling = None
vision_model_output: BaseModelOutputWithPoolingAndNoAttention = None
def to_tuple(self) -> tuple[Any]:
@ -508,7 +509,6 @@ class AlignVisionEncoder(nn.Module):
)
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->AlignText
class AlignTextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
@ -537,7 +537,6 @@ class AlignTextEmbeddings(nn.Module):
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
@ -547,7 +546,7 @@ class AlignTextEmbeddings(nn.Module):
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, :seq_length]
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@ -573,9 +572,35 @@ class AlignTextEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->AlignText
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class AlignTextSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@ -583,6 +608,7 @@ class AlignTextSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})"
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -592,20 +618,12 @@ class AlignTextSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.attention_dropout = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -615,96 +633,33 @@ class AlignTextSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
query_layer = self.transpose_for_scores(mixed_query_layer)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
head_mask=head_mask,
**kwargs,
)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in AlignTextModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
@ -723,18 +678,10 @@ class AlignTextSelfOutput(nn.Module):
return hidden_states
ALIGN_TEXT_SELF_ATTENTION_CLASSES = {
"eager": AlignTextSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT
class AlignTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.self = AlignTextSelfAttention(config)
self.output = AlignTextSelfOutput(config)
self.pruned_heads = set()
@ -756,6 +703,9 @@ class AlignTextAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -765,15 +715,14 @@ class AlignTextAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -811,22 +760,18 @@ class AlignTextOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText
class AlignTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = AlignTextAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = AlignTextAttention(config, position_embedding_type="absolute")
self.intermediate = AlignTextIntermediate(config)
self.output = AlignTextOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -836,60 +781,23 @@ class AlignTextLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
**kwargs,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
@ -898,14 +806,18 @@ class AlignTextLayer(GradientCheckpointingLayer):
return layer_output
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->AlignText
class AlignTextEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([AlignTextLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([AlignTextLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward(
self,
hidden_states: torch.Tensor,
@ -918,65 +830,36 @@ class AlignTextEncoder(nn.Module):
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
**kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutput(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
@ -1052,6 +935,7 @@ class AlignTextModel(AlignPreTrainedModel):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1059,12 +943,13 @@ class AlignTextModel(AlignPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
**kwargs,
) -> Union[tuple, BaseModelOutputWithPooling]:
r"""
Examples:
@ -1133,20 +1018,17 @@ class AlignTextModel(AlignPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
@ -1180,6 +1062,7 @@ class AlignVisionModel(AlignPreTrainedModel):
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.convolution
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1219,7 +1102,7 @@ class AlignVisionModel(AlignPreTrainedModel):
encoder_outputs = self.encoder(
embedding_output,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
# Apply pooling
last_hidden_state = encoder_outputs[0]
@ -1227,9 +1110,6 @@ class AlignVisionModel(AlignPreTrainedModel):
# Reshape (batch_size, projection_dim, 1 , 1) -> (batch_size, projection_dim)
pooled_output = pooled_output.reshape(pooled_output.shape[:2])
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
@ -1369,6 +1249,7 @@ class AlignModel(AlignPreTrainedModel):
return image_features
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1419,7 +1300,7 @@ class AlignModel(AlignPreTrainedModel):
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
text_outputs = self.text_model(
@ -1431,7 +1312,7 @@ class AlignModel(AlignPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
image_embeds = vision_outputs[1]
@ -1450,10 +1331,6 @@ class AlignModel(AlignPreTrainedModel):
if return_loss:
loss = align_loss(logits_per_text)
if not return_dict:
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
return ((loss,) + output) if loss is not None else output
return AlignOutput(
loss=loss,
logits_per_image=logits_per_image,

View File

@ -26,14 +26,14 @@ from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
BaseModelOutputWithPoolingAndProjection,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig
@ -180,7 +180,6 @@ class AltRobertaEmbeddings(nn.Module):
return position_ids.unsqueeze(0).expand(input_shape)
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->AltRoberta
class AltRobertaSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
@ -206,13 +205,9 @@ class AltRobertaSelfAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -223,55 +218,19 @@ class AltRobertaSelfAttention(nn.Module):
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
@ -310,8 +269,6 @@ class AltRobertaSelfAttention(nn.Module):
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
@ -335,7 +292,6 @@ ALT_ROBERTA_SELF_ATTENTION_CLASSES = {
}
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA
class AltRobertaAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
@ -363,6 +319,9 @@ class AltRobertaAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -375,12 +334,9 @@ class AltRobertaAttention(nn.Module):
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -418,22 +374,19 @@ class AltRobertaOutput(nn.Module):
return hidden_states
# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->AltRoberta
class AltRobertaLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = AltRobertaAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = AltRobertaAttention(config, position_embedding_type="absolute")
self.intermediate = AltRobertaIntermediate(config)
self.output = AltRobertaOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -443,60 +396,23 @@ class AltRobertaLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
**kwargs,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
@ -505,14 +421,19 @@ class AltRobertaLayer(GradientCheckpointingLayer):
return layer_output
# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->AltRoberta
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->AltRoberta
class AltRobertaEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([AltRobertaLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([AltRobertaLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward(
self,
hidden_states: torch.Tensor,
@ -525,65 +446,36 @@ class AltRobertaEncoder(nn.Module):
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
**kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutput(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
@ -787,6 +679,7 @@ class AltCLIPEncoder(nn.Module):
self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@can_return_tuple
def forward(
self,
inputs_embeds,
@ -853,8 +746,6 @@ class AltCLIPEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
@ -1008,6 +899,7 @@ class AltCLIPVisionTransformer(nn.Module):
self.encoder = AltCLIPEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1033,16 +925,13 @@ class AltCLIPVisionTransformer(nn.Module):
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
@ -1106,16 +995,11 @@ class AltCLIPVisionModel(AltCLIPPreTrainedModel):
@auto_docstring(
custom_intro="""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in *Attention is
The model behaves as an encoder following the architecture described in *Attention is
all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
Kaiser and Illia Polosukhin.
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
.. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
.. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
"""
)
class AltRobertaModel(AltCLIPPreTrainedModel):
@ -1152,6 +1036,10 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@auto_docstring
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
@ -1176,11 +1064,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -1194,11 +1077,8 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
@ -1212,21 +1092,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
@ -1235,33 +1100,23 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
@ -1284,6 +1139,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
return super().resize_token_embeddings(new_num_tokens)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1326,11 +1184,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
# last module outputs
@ -1343,9 +1199,6 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
projection_state = self.transformation(sequence_output)
pooler_output = projection_state[:, 0]
if not return_dict:
return (projection_state, pooler_output) + outputs[2:4]
return BaseModelOutputWithPoolingAndProjection(
last_hidden_state=projection_state,
pooler_output=pooler_output,

View File

@ -1026,7 +1026,6 @@ class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@auto_docstring
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

View File

@ -26,13 +26,14 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_bros import BrosConfig
@ -150,7 +151,6 @@ class BrosTextEmbeddings(nn.Module):
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
@ -160,7 +160,7 @@ class BrosTextEmbeddings(nn.Module):
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
@ -208,14 +208,7 @@ class BrosSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor):
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -227,42 +220,21 @@ class BrosSelfAttention(nn.Module):
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[torch.Tensor] = False,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
if is_cross_attention:
key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
@ -317,7 +289,7 @@ class BrosSelfAttention(nn.Module):
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
outputs = outputs + (None,)
return outputs
@ -364,6 +336,7 @@ class BrosAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -382,7 +355,6 @@ class BrosAttention(nn.Module):
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
@ -435,6 +407,7 @@ class BrosLayer(GradientCheckpointingLayer):
self.intermediate = BrosIntermediate(config)
self.output = BrosOutput(config)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -446,50 +419,38 @@ class BrosLayer(GradientCheckpointingLayer):
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
bbox_pos_emb=bbox_pos_emb,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if hasattr(self, "crossattention"):
raise Exception(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
@ -500,7 +461,7 @@ class BrosLayer(GradientCheckpointingLayer):
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
outputs = outputs + (None,)
return outputs
@ -516,6 +477,9 @@ class BrosEncoder(nn.Module):
self.config = config
self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)])
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward(
self,
hidden_states: torch.Tensor,
@ -529,33 +493,28 @@ class BrosEncoder(nn.Module):
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
) -> Union[tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
bbox_pos_emb,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
hidden_states=hidden_states,
bbox_pos_emb=bbox_pos_emb,
attention_mask=attention_mask,
head_mask=layer_head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
@ -564,21 +523,8 @@ class BrosEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
@ -689,6 +635,9 @@ class BrosModel(BrosPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
@auto_docstring
def forward(
self,
@ -736,11 +685,6 @@ class BrosModel(BrosPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -756,9 +700,6 @@ class BrosModel(BrosPreTrainedModel):
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
@ -797,7 +738,6 @@ class BrosModel(BrosPreTrainedModel):
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
@ -813,22 +753,16 @@ class BrosModel(BrosPreTrainedModel):
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
@ -852,6 +786,7 @@ class BrosForTokenClassification(BrosPreTrainedModel):
self.init_weights()
@can_return_tuple
@auto_docstring
def forward(
self,
@ -908,7 +843,7 @@ class BrosForTokenClassification(BrosPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = outputs[0]
@ -927,10 +862,6 @@ class BrosForTokenClassification(BrosPreTrainedModel):
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
@ -976,6 +907,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
self.init_weights()
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1037,7 +969,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
last_hidden_states = outputs[0]
@ -1082,10 +1014,6 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
loss = initial_token_loss + subsequent_token_loss
if not return_dict:
output = (initial_token_logits, subsequent_token_logits) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return BrosSpadeOutput(
loss=loss,
initial_token_logits=initial_token_logits,
@ -1118,6 +1046,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
self.init_weights()
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1173,7 +1102,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
last_hidden_states = outputs[0]
@ -1203,10 +1132,6 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask])
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,

View File

@ -14,9 +14,8 @@
# limitations under the License.
"""PyTorch Chinese-CLIP model."""
import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.utils.checkpoint
@ -26,13 +25,13 @@ from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig
@ -90,7 +89,7 @@ class ChineseCLIPOutput(ModelOutput):
)
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->ChineseCLIPText
# Copied from transformers.models.align.modeling_align.AlignTextEmbeddings with Align->ChineseCLIP
class ChineseCLIPTextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
@ -119,7 +118,6 @@ class ChineseCLIPTextEmbeddings(nn.Module):
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
@ -129,7 +127,7 @@ class ChineseCLIPTextEmbeddings(nn.Module):
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, :seq_length]
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@ -239,9 +237,37 @@ class ChineseCLIPVisionEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ChineseCLIPText
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->ChineseCLIP
class ChineseCLIPTextSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@ -249,6 +275,7 @@ class ChineseCLIPTextSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})"
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -258,20 +285,12 @@ class ChineseCLIPTextSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.attention_dropout = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -281,96 +300,33 @@ class ChineseCLIPTextSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
query_layer = self.transpose_for_scores(mixed_query_layer)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
head_mask=head_mask,
**kwargs,
)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in ChineseCLIPTextModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
@ -389,18 +345,11 @@ class ChineseCLIPTextSelfOutput(nn.Module):
return hidden_states
CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = {
"eager": ChineseCLIPTextSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT
# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->ChineseCLIP
class ChineseCLIPTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.self = ChineseCLIPTextSelfAttention(config)
self.output = ChineseCLIPTextSelfOutput(config)
self.pruned_heads = set()
@ -422,6 +371,9 @@ class ChineseCLIPTextAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -431,15 +383,14 @@ class ChineseCLIPTextAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -468,66 +419,37 @@ class ChineseCLIPVisionAttention(nn.Module):
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
output_attentions: Optional[bool] = False,
self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, **kwargs
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
# get query proj
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.scale
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit akward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
None,
dropout=0.0 if not self.training else self.dropout,
scaling=1.0,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
return attn_output, attn_weights
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText
@ -577,22 +499,19 @@ class ChineseCLIPVisionMLP(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText
# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->ChineseCLIP
class ChineseCLIPTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ChineseCLIPTextAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = ChineseCLIPTextAttention(config, position_embedding_type="absolute")
self.intermediate = ChineseCLIPTextIntermediate(config)
self.output = ChineseCLIPTextOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -602,60 +521,23 @@ class ChineseCLIPTextLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
**kwargs,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
@ -777,14 +659,19 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ChineseCLIPText
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP
class ChineseCLIPTextEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward(
self,
hidden_states: torch.Tensor,
@ -797,65 +684,36 @@ class ChineseCLIPTextEncoder(nn.Module):
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
**kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutput(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
@ -874,6 +732,7 @@ class ChineseCLIPVisionEncoder(nn.Module):
self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@can_return_tuple
def forward(
self,
inputs_embeds,
@ -922,8 +781,6 @@ class ChineseCLIPVisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
@ -940,6 +797,7 @@ class ChineseCLIPVisionTransformer(nn.Module):
self.encoder = ChineseCLIPVisionEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@can_return_tuple
@auto_docstring
def forward(
self,
@ -965,16 +823,13 @@ class ChineseCLIPVisionTransformer(nn.Module):
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
@ -1034,6 +889,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1050,18 +906,13 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -1093,56 +944,28 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
@ -1343,6 +1166,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
return image_features
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1392,7 +1216,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
return_dict=True,
)
text_outputs = self.text_model(
@ -1402,7 +1226,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
image_embeds = vision_outputs[1]
@ -1424,14 +1248,6 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
if return_loss:
loss = chinese_clip_loss(logits_per_text)
if not return_dict:
# fix the None pooled_output of text_outputs to conform with dict_output
pooled_output = text_outputs[1]
if pooled_output is None:
text_outputs = (text_outputs[0],) + text_outputs[2:]
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
return ((loss,) + output) if loss is not None else output
return ChineseCLIPOutput(
loss=loss,
logits_per_image=logits_per_image,

View File

@ -17,7 +17,7 @@
import collections
import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.nn.functional as F
@ -26,13 +26,14 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutput,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
@ -399,11 +400,6 @@ class ClapAudioSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
@ -412,11 +408,11 @@ class ClapAudioSelfAttention(nn.Module):
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
hidden_shape = (batch_size, dim, -1, self.attention_head_size)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
@ -1090,9 +1086,37 @@ class ClapTextEmbeddings(nn.Module):
return position_ids.unsqueeze(0).expand(input_shape)
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ClapText
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->Clap
class ClapTextSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@ -1100,6 +1124,7 @@ class ClapTextSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})"
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -1109,20 +1134,12 @@ class ClapTextSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.attention_dropout = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -1132,96 +1149,33 @@ class ClapTextSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
query_layer = self.transpose_for_scores(mixed_query_layer)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
head_mask=head_mask,
**kwargs,
)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in ClapTextModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
@ -1240,18 +1194,11 @@ class ClapTextSelfOutput(nn.Module):
return hidden_states
CLAP_TEXT_SELF_ATTENTION_CLASSES = {
"eager": ClapTextSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT
# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->Clap
class ClapTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.self = ClapTextSelfAttention(config)
self.output = ClapTextSelfOutput(config)
self.pruned_heads = set()
@ -1273,6 +1220,9 @@ class ClapTextAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -1282,15 +1232,14 @@ class ClapTextAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -1328,22 +1277,19 @@ class ClapTextOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText
# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->Clap
class ClapTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ClapTextAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = ClapTextAttention(config, position_embedding_type="absolute")
self.intermediate = ClapTextIntermediate(config)
self.output = ClapTextOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -1353,60 +1299,23 @@ class ClapTextLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
**kwargs,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
@ -1415,14 +1324,19 @@ class ClapTextLayer(GradientCheckpointingLayer):
return layer_output
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ClapText
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->Clap
class ClapTextEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([ClapTextLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([ClapTextLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward(
self,
hidden_states: torch.Tensor,
@ -1435,65 +1349,36 @@ class ClapTextEncoder(nn.Module):
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
**kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutput(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
@ -1643,6 +1528,11 @@ class ClapTextModel(ClapPreTrainedModel):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1666,11 +1556,6 @@ class ClapTextModel(ClapPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -1684,11 +1569,8 @@ class ClapTextModel(ClapPreTrainedModel):
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
@ -1702,21 +1584,6 @@ class ClapTextModel(ClapPreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
@ -1725,33 +1592,23 @@ class ClapTextModel(ClapPreTrainedModel):
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
@ -1892,6 +1749,7 @@ class ClapModel(ClapPreTrainedModel):
return audio_features
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1947,7 +1805,7 @@ class ClapModel(ClapPreTrainedModel):
is_longer=is_longer,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
text_outputs = self.text_model(
@ -1956,7 +1814,7 @@ class ClapModel(ClapPreTrainedModel):
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
audio_embeds = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
@ -1981,10 +1839,6 @@ class ClapModel(ClapPreTrainedModel):
audio_loss = contrastive_loss(logits_per_audio.t())
loss = (caption_loss + audio_loss) / 2.0
if not return_dict:
output = (logits_per_audio, logits_per_text, text_embeds, audio_embeds, text_outputs, audio_outputs)
return ((loss,) + output) if loss is not None else output
return ClapOutput(
loss=loss,
logits_per_audio=logits_per_audio,
@ -2013,6 +1867,7 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
def set_input_embeddings(self, value):
self.text_model.embeddings.word_embeddings = value
@can_return_tuple
@auto_docstring
def forward(
self,
@ -2045,17 +1900,13 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output
text_embeds = self.text_projection(pooled_output)
if not return_dict:
outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
return tuple(output for output in outputs if output is not None)
return ClapTextModelOutput(
text_embeds=text_embeds,
last_hidden_state=text_outputs.last_hidden_state,
@ -2079,6 +1930,7 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
def get_input_embeddings(self) -> nn.Module:
return self.audio_model.audio_encoder.patch_embed.proj
@can_return_tuple
@auto_docstring
def forward(
self,
@ -2123,17 +1975,13 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
is_longer=is_longer,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
audio_embeds = self.audio_projection(pooled_output)
if not return_dict:
outputs = (audio_embeds, audio_outputs[0]) + audio_outputs[2:]
return tuple(output for output in outputs if output is not None)
return ClapAudioModelOutput(
audio_embeds=audio_embeds,
last_hidden_state=audio_outputs.last_hidden_state,

View File

@ -28,7 +28,7 @@ from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig
@ -490,6 +490,7 @@ class CLIPSegEncoder(nn.Module):
self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@can_return_tuple
def forward(
self,
inputs_embeds,
@ -555,8 +556,6 @@ class CLIPSegEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)

View File

@ -45,6 +45,7 @@ from ...modeling_outputs import (
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available
from ...utils.deprecation import deprecate_kwarg
from .configuration_data2vec_audio import Data2VecAudioConfig
@ -240,6 +241,7 @@ class Data2VecAudioAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -247,7 +249,7 @@ class Data2VecAudioAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
@ -268,42 +270,9 @@ class Data2VecAudioAttention(nn.Module):
# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
current_states = key_value_states if is_cross_attention else hidden_states
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@ -325,7 +294,7 @@ class Data2VecAudioAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None
class Data2VecAudioFeedForward(nn.Module):

View File

@ -634,7 +634,6 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@auto_docstring
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

View File

@ -405,11 +405,6 @@ class DonutSwinSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
@ -418,11 +413,11 @@ class DonutSwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
hidden_shape = (batch_size, dim, -1, self.attention_head_size)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

View File

@ -26,14 +26,15 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_esm import EsmConfig
@ -187,12 +188,16 @@ class EsmEmbeddings(nn.Module):
self.mask_token_id = config.mask_token_id
def forward(
self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
self,
input_ids=None,
attention_mask=None,
position_ids=None,
inputs_embeds=None,
):
if position_ids is None:
if input_ids is not None:
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
@ -281,11 +286,7 @@ class EsmSelfAttention(nn.Module):
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -296,32 +297,22 @@ class EsmSelfAttention(nn.Module):
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
if is_cross_attention:
key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
# ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
@ -329,16 +320,6 @@ class EsmSelfAttention(nn.Module):
# ESM code and fix rotary embeddings.
query_layer = query_layer * self.attention_head_size**-0.5
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
@ -385,7 +366,7 @@ class EsmSelfAttention(nn.Module):
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
outputs = outputs + (None,)
return outputs
@ -418,6 +399,7 @@ class EsmFlashAttention2(EsmSelfAttention):
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
self.dropout_prob = config.attention_probs_dropout_prob
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -441,7 +423,6 @@ class EsmFlashAttention2(EsmSelfAttention):
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
@ -450,9 +431,6 @@ class EsmFlashAttention2(EsmSelfAttention):
query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
if past_key_value is not None:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
@ -514,7 +492,7 @@ class EsmFlashAttention2(EsmSelfAttention):
outputs = (attn_output, None)
if self.is_decoder:
outputs = outputs + (past_key_value,)
outputs = outputs + (None,)
return outputs
@ -551,6 +529,7 @@ class EsmAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states,
@ -564,12 +543,11 @@ class EsmAttention(nn.Module):
hidden_states_ln = self.LayerNorm(hidden_states)
self_outputs = self.self(
hidden_states_ln,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -616,6 +594,7 @@ class EsmLayer(GradientCheckpointingLayer):
self.output = EsmOutput(config)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states,
@ -626,25 +605,20 @@ class EsmLayer(GradientCheckpointingLayer):
past_key_value=None,
output_attentions=False,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise AttributeError(
@ -652,31 +626,24 @@ class EsmLayer(GradientCheckpointingLayer):
" with cross-attention layers by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = self.feed_forward_chunk(attention_output)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
outputs = outputs + (None,)
return outputs
def feed_forward_chunk(self, attention_output):
@ -694,6 +661,9 @@ class EsmEncoder(nn.Module):
self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
@deprecate_kwarg("past_key_value", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward(
self,
hidden_states,
@ -707,38 +677,26 @@ class EsmEncoder(nn.Module):
output_hidden_states=False,
return_dict=True,
):
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
@ -750,21 +708,8 @@ class EsmEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
@ -863,6 +808,9 @@ class EsmModel(EsmPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
@auto_docstring
def forward(
self,
@ -903,11 +851,6 @@ class EsmModel(EsmPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -921,11 +864,8 @@ class EsmModel(EsmPreTrainedModel):
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
if self.config._attn_implementation == "flash_attention_2":
extended_attention_mask = attention_mask
@ -958,7 +898,6 @@ class EsmModel(EsmPreTrainedModel):
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
@ -966,22 +905,16 @@ class EsmModel(EsmPreTrainedModel):
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
@ -1025,6 +958,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1058,7 +992,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
@ -1070,10 +1004,6 @@ class EsmForMaskedLM(EsmPreTrainedModel):
labels = labels.to(prediction_scores.device)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
@ -1125,6 +1055,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1154,7 +1085,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
@ -1184,10 +1115,6 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
@ -1210,6 +1137,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1237,7 +1165,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = outputs[0]
@ -1252,10 +1180,6 @@ class EsmForTokenClassification(EsmPreTrainedModel):
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
@ -1283,7 +1207,7 @@ class EsmClassificationHead(nn.Module):
return x
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
@ -1295,7 +1219,7 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
return incremental_indices.long() + padding_idx

View File

@ -39,6 +39,7 @@ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and
from ...utils import (
ModelOutput,
auto_docstring,
can_return_tuple,
logging,
torch_int,
)
@ -770,6 +771,7 @@ class GitVisionEncoder(nn.Module):
self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@can_return_tuple
def forward(
self,
inputs_embeds,
@ -836,8 +838,6 @@ class GitVisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)

View File

@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_hubert import HubertConfig
@ -300,6 +301,7 @@ class HubertAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -307,7 +309,7 @@ class HubertAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
@ -328,42 +330,9 @@ class HubertAttention(nn.Module):
# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
current_states = key_value_states if is_cross_attention else hidden_states
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@ -385,7 +354,7 @@ class HubertAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None
class HubertFeedForward(nn.Module):

View File

@ -28,6 +28,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...utils import (
ModelOutput,
can_return_tuple,
logging,
)
from .configuration_idefics import IdeficsVisionConfig
@ -351,6 +352,7 @@ class IdeficsVisionEncoder(nn.Module):
self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@can_return_tuple
def forward(
self,
inputs_embeds,
@ -417,8 +419,6 @@ class IdeficsVisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)

View File

@ -451,6 +451,7 @@ class Kosmos2VisionEncoder(nn.Module):
self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@can_return_tuple
def forward(
self,
inputs_embeds,
@ -517,8 +518,6 @@ class Kosmos2VisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)

View File

@ -14,6 +14,7 @@
# limitations under the License.
"""LayoutLM model configuration"""
import warnings
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any, Optional
@ -130,10 +131,22 @@ class LayoutLMConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self._position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.max_2d_position_embeddings = max_2d_position_embeddings
@property
def position_embedding_type(self):
warnings.warn(
"The `position_embedding_type` attribute is deprecated and will be removed in v4.55.",
FutureWarning,
)
return self._position_embedding_type
@position_embedding_type.setter
def position_embedding_type(self, value):
self._position_embedding_type = value
class LayoutLMOnnxConfig(OnnxConfig):
def __init__(

View File

@ -14,8 +14,7 @@
# limitations under the License.
"""PyTorch LayoutLM model."""
import math
from typing import Optional, Union
from typing import Callable, Optional, Union
import torch
import torch.utils.checkpoint
@ -25,16 +24,17 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_layoutlm import LayoutLMConfig
@ -120,9 +120,37 @@ class LayoutLMEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->LayoutLM
class LayoutLMSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@ -130,6 +158,7 @@ class LayoutLMSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})"
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -139,20 +168,12 @@ class LayoutLMSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.attention_dropout = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -162,96 +183,33 @@ class LayoutLMSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
query_layer = self.transpose_for_scores(mixed_query_layer)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
head_mask=head_mask,
**kwargs,
)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in LayoutLMModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
@ -270,18 +228,11 @@ class LayoutLMSelfOutput(nn.Module):
return hidden_states
LAYOUTLM_SELF_ATTENTION_CLASSES = {
"eager": LayoutLMSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->LayoutLM
class LayoutLMAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.self = LayoutLMSelfAttention(config)
self.output = LayoutLMSelfOutput(config)
self.pruned_heads = set()
@ -303,6 +254,9 @@ class LayoutLMAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -312,15 +266,14 @@ class LayoutLMAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -358,22 +311,19 @@ class LayoutLMOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->LayoutLM
class LayoutLMLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = LayoutLMAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute")
self.intermediate = LayoutLMIntermediate(config)
self.output = LayoutLMOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -383,60 +333,23 @@ class LayoutLMLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
**kwargs,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
@ -445,14 +358,19 @@ class LayoutLMLayer(GradientCheckpointingLayer):
return layer_output
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->LayoutLM
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->LayoutLM
class LayoutLMEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([LayoutLMLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward(
self,
hidden_states: torch.Tensor,
@ -465,65 +383,36 @@ class LayoutLMEncoder(nn.Module):
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
**kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutput(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
@ -648,6 +537,9 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@can_return_tuple
@auto_docstring
def forward(
self,
@ -663,7 +555,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
) -> Union[tuple, BaseModelOutputWithPooling]:
r"""
bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
Bounding boxes of each input sequence tokens. Selected in the range `[0,
@ -756,20 +648,16 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
@ -796,6 +684,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@can_return_tuple
@auto_docstring
def forward(
self,
@ -871,11 +762,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = outputs[0]
@ -889,10 +778,6 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
labels.view(-1),
)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
@ -921,6 +806,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings
@can_return_tuple
@auto_docstring
def forward(
self,
@ -996,7 +882,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
pooled_output = outputs[1]
@ -1026,9 +912,6 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
@ -1059,6 +942,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1132,7 +1016,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = outputs[0]
@ -1145,10 +1029,6 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
@ -1176,6 +1056,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1253,7 +1134,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = outputs[0]
@ -1280,10 +1161,6 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,

View File

@ -14,6 +14,8 @@
# limitations under the License.
"""MarkupLM model configuration"""
import warnings
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@ -141,7 +143,7 @@ class MarkupLMConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self._position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
# additional properties
@ -152,5 +154,17 @@ class MarkupLMConfig(PretrainedConfig):
self.subs_pad_id = subs_pad_id
self.xpath_unit_hidden_size = xpath_unit_hidden_size
@property
def position_embedding_type(self):
warnings.warn(
"The `position_embedding_type` attribute is deprecated and will be removed in v4.55.",
FutureWarning,
)
return self._position_embedding_type
@position_embedding_type.setter
def position_embedding_type(self, value):
self._position_embedding_type = value
__all__ = ["MarkupLMConfig"]

View File

@ -14,9 +14,8 @@
# limitations under the License.
"""PyTorch MarkupLM model."""
import math
import os
from typing import Optional, Union
from typing import Callable, Optional, Union
import torch
import torch.utils.checkpoint
@ -26,20 +25,22 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
ALL_ATTENTION_FUNCTIONS,
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...utils import auto_docstring, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_markuplm import MarkupLMConfig
@ -326,9 +327,37 @@ class MarkupLMOnlyMLMHead(nn.Module):
return prediction_scores
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MarkupLM
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->MarkupLM
class MarkupLMSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@ -336,6 +365,7 @@ class MarkupLMSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})"
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -345,20 +375,12 @@ class MarkupLMSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.attention_dropout = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -368,111 +390,41 @@ class MarkupLMSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
query_layer = self.transpose_for_scores(mixed_query_layer)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
head_mask=head_mask,
**kwargs,
)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in MarkupLMModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
MARKUPLM_SELF_ATTENTION_CLASSES = {
"eager": MarkupLMSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM,BERT->MARKUPLM
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->MarkupLM
class MarkupLMAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
self.self = MARKUPLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.self = MarkupLMSelfAttention(config)
self.output = MarkupLMSelfOutput(config)
self.pruned_heads = set()
@ -494,6 +446,9 @@ class MarkupLMAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -503,37 +458,33 @@ class MarkupLMAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->MarkupLM
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->MarkupLM
class MarkupLMLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = MarkupLMAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = MarkupLMAttention(config, position_embedding_type="absolute")
self.intermediate = MarkupLMIntermediate(config)
self.output = MarkupLMOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -543,60 +494,23 @@ class MarkupLMLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
**kwargs,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
@ -605,14 +519,19 @@ class MarkupLMLayer(GradientCheckpointingLayer):
return layer_output
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->MarkupLM
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->MarkupLM
class MarkupLMEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([MarkupLMLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([MarkupLMLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward(
self,
hidden_states: torch.Tensor,
@ -625,65 +544,36 @@ class MarkupLMEncoder(nn.Module):
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
**kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutput(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
@ -749,6 +639,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@can_return_tuple
@auto_docstring
def forward(
self,
@ -763,7 +654,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
) -> Union[tuple, BaseModelOutputWithPooling]:
r"""
xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
Tag IDs for each token in the input sequence, padded up to config.max_depth.
@ -839,21 +730,16 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
# Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache
@ -879,6 +765,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
@ -939,7 +826,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = outputs[0]
@ -966,10 +853,6 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
@ -1000,6 +883,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1058,7 +942,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = outputs[0]
@ -1072,10 +956,6 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
labels.view(-1),
)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=prediction_scores,
@ -1107,6 +987,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
@ -1164,7 +1045,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
pooled_output = outputs[1]
@ -1194,9 +1075,6 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,

View File

@ -354,11 +354,6 @@ class MaskFormerSwinSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
@ -367,11 +362,11 @@ class MaskFormerSwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
hidden_shape = (batch_size, dim, -1, self.attention_head_size)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

View File

@ -182,7 +182,6 @@ def eager_attention_forward(
return attn_output, attn_weights
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Musicgen
class MusicgenAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

View File

@ -189,7 +189,7 @@ def eager_attention_forward(
return attn_output, attn_weights
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->MusicgenMelody
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->MusicgenMelody
class MusicgenMelodyAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

View File

@ -503,7 +503,7 @@ def eager_attention_forward(
return attn_output, attn_weights
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->NllbMoe,key_value_states->encoder_hidden_states
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->NllbMoe,key_value_states->encoder_hidden_states
class NllbMoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

View File

@ -29,6 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_patchtsmixer import PatchTSMixerConfig
@ -303,6 +304,7 @@ class PatchTSMixerAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -310,7 +312,7 @@ class PatchTSMixerAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
@ -331,42 +333,9 @@ class PatchTSMixerAttention(nn.Module):
# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
current_states = key_value_states if is_cross_attention else hidden_states
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@ -388,7 +357,7 @@ class PatchTSMixerAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None
class PatchMixerBlock(nn.Module):

View File

@ -28,6 +28,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_patchtst import PatchTSTConfig
@ -100,6 +101,7 @@ class PatchTSTAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -107,7 +109,7 @@ class PatchTSTAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
@ -128,42 +130,9 @@ class PatchTSTAttention(nn.Module):
# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
current_states = key_value_states if is_cross_attention else hidden_states
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@ -185,7 +154,7 @@ class PatchTSTAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None
class PatchTSTBatchNorm(nn.Module):

View File

@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_sew import SEWConfig
@ -293,6 +294,7 @@ class SEWAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -300,7 +302,7 @@ class SEWAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
@ -321,42 +323,9 @@ class SEWAttention(nn.Module):
# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
current_states = key_value_states if is_cross_attention else hidden_states
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@ -378,7 +347,7 @@ class SEWAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None
class SEWFeedForward(nn.Module):

View File

@ -205,7 +205,7 @@ def eager_attention_forward(
return attn_output, attn_weights
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Speech2Text
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->Speech2Text
class Speech2TextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

View File

@ -14,9 +14,8 @@
# limitations under the License.
"""PyTorch Splinter model."""
import math
from dataclasses import dataclass
from typing import Optional, Union
from typing import Callable, Optional, Union
import torch
import torch.utils.checkpoint
@ -25,13 +24,19 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_outputs import (
BaseModelOutput,
ModelOutput,
QuestionAnsweringModelOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
auto_docstring,
can_return_tuple,
logging,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_splinter import SplinterConfig
@ -64,7 +69,6 @@ class SplinterEmbeddings(nn.Module):
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: Optional[int] = 0,
) -> tuple:
if input_ids is not None:
input_shape = input_ids.size()
@ -74,7 +78,7 @@ class SplinterEmbeddings(nn.Module):
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, :seq_length]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
@ -92,9 +96,37 @@ class SplinterEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->Splinter
class SplinterSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@ -102,6 +134,7 @@ class SplinterSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})"
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -111,20 +144,12 @@ class SplinterSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.attention_dropout = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -134,96 +159,33 @@ class SplinterSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
query_layer = self.transpose_for_scores(mixed_query_layer)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
head_mask=head_mask,
**kwargs,
)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in SplinterModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
@ -242,18 +204,11 @@ class SplinterSelfOutput(nn.Module):
return hidden_states
SPLINTER_SELF_ATTENTION_CLASSES = {
"eager": SplinterSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->Splinter
class SplinterAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.self = SplinterSelfAttention(config)
self.output = SplinterSelfOutput(config)
self.pruned_heads = set()
@ -275,6 +230,9 @@ class SplinterAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -284,15 +242,14 @@ class SplinterAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -330,22 +287,19 @@ class SplinterOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->Splinter
class SplinterLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = SplinterAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = SplinterAttention(config, position_embedding_type="absolute")
self.intermediate = SplinterIntermediate(config)
self.output = SplinterOutput(config)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -355,60 +309,23 @@ class SplinterLayer(GradientCheckpointingLayer):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
**kwargs,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
@ -417,14 +334,19 @@ class SplinterLayer(GradientCheckpointingLayer):
return layer_output
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Splinter
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->Splinter
class SplinterEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([SplinterLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward(
self,
hidden_states: torch.Tensor,
@ -437,65 +359,36 @@ class SplinterEncoder(nn.Module):
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
**kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
output_attentions=output_attentions,
**kwargs,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutput(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
@ -554,6 +447,11 @@ class SplinterModel(SplinterPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
@auto_docstring
def forward(
self,
@ -570,7 +468,7 @@ class SplinterModel(SplinterPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
) -> Union[tuple, BaseModelOutput]:
r"""
token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
@ -592,11 +490,6 @@ class SplinterModel(SplinterPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -610,11 +503,8 @@ class SplinterModel(SplinterPreTrainedModel):
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
@ -622,17 +512,6 @@ class SplinterModel(SplinterPreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
@ -645,31 +524,21 @@ class SplinterModel(SplinterPreTrainedModel):
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
sequence_output = encoder_outputs[0]
if not return_dict:
return (sequence_output,) + encoder_outputs[1:]
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutput(
last_hidden_state=sequence_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)

View File

@ -435,11 +435,6 @@ class SwinSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
@ -448,11 +443,11 @@ class SwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
hidden_shape = (batch_size, dim, -1, self.attention_head_size)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

View File

@ -45,6 +45,7 @@ from ...modeling_outputs import (
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_unispeech import UniSpeechConfig
@ -332,6 +333,7 @@ class UniSpeechAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -339,7 +341,7 @@ class UniSpeechAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
@ -360,42 +362,9 @@ class UniSpeechAttention(nn.Module):
# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
current_states = key_value_states if is_cross_attention else hidden_states
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@ -417,7 +386,7 @@ class UniSpeechAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None
class UniSpeechFeedForward(nn.Module):

View File

@ -47,6 +47,7 @@ from ...modeling_outputs import (
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_unispeech_sat import UniSpeechSatConfig
@ -337,6 +338,7 @@ class UniSpeechSatAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -344,7 +346,7 @@ class UniSpeechSatAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
@ -365,42 +367,9 @@ class UniSpeechSatAttention(nn.Module):
# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
current_states = key_value_states if is_cross_attention else hidden_states
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@ -422,7 +391,7 @@ class UniSpeechSatAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None
class UniSpeechSatFeedForward(nn.Module):

View File

@ -55,6 +55,7 @@ from ...utils import (
is_torch_flex_attn_available,
logging,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_wav2vec2 import Wav2Vec2Config
@ -524,6 +525,7 @@ class Wav2Vec2Attention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -531,7 +533,7 @@ class Wav2Vec2Attention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
@ -552,42 +554,9 @@ class Wav2Vec2Attention(nn.Module):
# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
current_states = key_value_states if is_cross_attention else hidden_states
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@ -609,7 +578,7 @@ class Wav2Vec2Attention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None
class Wav2Vec2FeedForward(nn.Module):

View File

@ -30,6 +30,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
ModelOutput,
auto_docstring,
can_return_tuple,
logging,
torch_int,
)
@ -576,6 +577,7 @@ class XCLIPEncoder(nn.Module):
self.layers = nn.ModuleList([XCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@can_return_tuple
def forward(
self,
inputs_embeds,
@ -642,8 +644,6 @@ class XCLIPEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)

View File

@ -297,7 +297,7 @@ class AltCLIPTextModelTester:
@require_torch
class AltCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (AltCLIPTextModel,) if is_torch_available() else ()
fx_compatible = True
fx_compatible = False # Cannot support if `can_return_tuple`
test_pruning = False
test_head_masking = False
@ -411,7 +411,7 @@ def prepare_img():
class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (AltCLIPModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": AltCLIPModel} if is_torch_available() else {}
fx_compatible = True
fx_compatible = False # Cannot support if `can_return_tuple`
test_head_masking = False
test_pruning = False
test_resize_embeddings = False

View File

@ -243,7 +243,7 @@ class LayoutLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
if is_torch_available()
else {}
)
fx_compatible = True
fx_compatible = False # Cannot support if `can_return_tuple`
def setUp(self):
self.model_tester = LayoutLMModelTester(self)

View File

@ -372,6 +372,18 @@ class SplinterModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))
@unittest.skip(
"Splinter GC with `use_reentrant` fails after #38751, FIXME raushan after deprecated args are removed"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
"Splinter GC with `use_reentrant` fails after #38751, FIXME raushan after deprecated args are removed"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@require_torch
class SplinterModelIntegrationTest(unittest.TestCase):

View File

@ -276,6 +276,9 @@ SPECIAL_CASES_TO_ALLOW = {
"attention_chunk_size",
],
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
# position_embedding_type not used and deprecated. Should be deleted in v4.55
"LayoutLMConfig": ["position_embedding_type"],
"MarkupLMConfig": ["position_embedding_type"],
"SmolLM3Config": ["no_rope_layer_interval"],
"Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"], # this is for use in `timm`
}