Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
🚨 Don't use cache in non-generative models (#38751)
* deprecate for 1 version
* style
* fix some tests
* fix esm
* skip for now, GC requires positional args but we have keyword args
* remove transpose for scores in modified models only
* skip fx trace tests
This commit is contained in:
parent: dbc98328da
commit: e435574721
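As the hunks below show, the non-generative text encoders touched here (ALIGN, AltCLIP's RoBERTa encoder, Bros, BridgeTower) drop the BERT-style `transpose_for_scores`/`past_key_value` machinery: the cache- and cross-attention-related keyword arguments are deprecated for one release, and self-attention is routed through a stateless `eager_attention_forward` helper that can be swapped via `ALL_ATTENTION_FUNCTIONS`. The following is a minimal, self-contained sketch of that dispatch pattern; the names `ATTENTION_IMPLEMENTATIONS` and `SelfAttentionSketch` are illustrative assumptions standing in for the library's registry and per-model attention classes, not the actual implementation.

# Minimal sketch (not the exact library code) of the cache-free attention dispatch
# introduced in the diff below.
from typing import Callable

import torch
from torch import nn


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, head_mask=None, **kwargs):
    # query/key/value: (batch, heads, seq_len, head_dim); plain scaled dot-product attention, no key/value cache.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


# Hypothetical registry standing in for ALL_ATTENTION_FUNCTIONS.
ATTENTION_IMPLEMENTATIONS = {"eager": eager_attention_forward}


class SelfAttentionSketch(nn.Module):
    def __init__(self, hidden_size=64, num_heads=4, attn_implementation="eager"):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim**-0.5
        self.attn_implementation = attn_implementation
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        query = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
        # Dispatch on the configured implementation instead of hard-coding one path,
        # and neither accept nor return a past_key_value cache.
        attention_fn: Callable = ATTENTION_IMPLEMENTATIONS.get(self.attn_implementation, eager_attention_forward)
        attn_output, attn_weights = attention_fn(self, query, key, value, attention_mask, scaling=self.scaling)
        return attn_output.reshape(*input_shape, -1), attn_weights


if __name__ == "__main__":
    out, weights = SelfAttentionSketch()(torch.randn(2, 7, 64))
    print(out.shape, weights.shape)  # torch.Size([2, 7, 64]) torch.Size([2, 4, 7, 7])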
@@ -16,7 +16,7 @@
import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.utils.checkpoint

@@ -25,14 +25,15 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithNoAttention,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig

@@ -90,7 +91,7 @@ class AlignOutput(ModelOutput):
The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`].
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
The output of [`AlignVisionModel`].
text_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`):
text_model_output (`BaseModelOutputWithPooling`):
The output of the [`AlignTextModel`].
vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`):
The output of the [`AlignVisionModel`].

@@ -101,7 +102,7 @@ class AlignOutput(ModelOutput):
logits_per_text: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
image_embeds: Optional[torch.FloatTensor] = None
text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
text_model_output: BaseModelOutputWithPooling = None
vision_model_output: BaseModelOutputWithPoolingAndNoAttention = None

def to_tuple(self) -> tuple[Any]:
@@ -508,7 +509,6 @@ class AlignVisionEncoder(nn.Module):
)


# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->AlignText
class AlignTextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -537,7 +537,6 @@ class AlignTextEmbeddings(nn.Module):
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()

@@ -547,7 +546,7 @@
seq_length = input_shape[1]

if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, :seq_length]

# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves

@@ -573,9 +572,35 @@
return embeddings


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->AlignText
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights


class AlignTextSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@@ -583,6 +608,7 @@ class AlignTextSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})"
)

self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size

@@ -592,20 +618,12 @@
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_decoder = config.is_decoder

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.attention_dropout = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5

@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -615,96 +633,33 @@ class AlignTextSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

query_layer = self.transpose_for_scores(mixed_query_layer)

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
head_mask=head_mask,
**kwargs,
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r

positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in AlignTextModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
@@ -723,18 +678,10 @@ class AlignTextSelfOutput(nn.Module):
return hidden_states


ALIGN_TEXT_SELF_ATTENTION_CLASSES = {
"eager": AlignTextSelfAttention,
}


# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT
class AlignTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.self = AlignTextSelfAttention(config)
self.output = AlignTextSelfOutput(config)
self.pruned_heads = set()

@@ -756,6 +703,9 @@ class AlignTextAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)

@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,

@@ -765,15 +715,14 @@ class AlignTextAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@ -811,22 +760,18 @@ class AlignTextOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText
|
||||
class AlignTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = AlignTextAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = AlignTextAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = AlignTextIntermediate(config)
|
||||
self.output = AlignTextOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -836,60 +781,23 @@ class AlignTextLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -898,14 +806,18 @@ class AlignTextLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->AlignText
|
||||
class AlignTextEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([AlignTextLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([AlignTextLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -918,65 +830,36 @@ class AlignTextEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1052,6 +935,7 @@ class AlignTextModel(AlignPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1059,12 +943,13 @@ class AlignTextModel(AlignPreTrainedModel):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Examples:
|
||||
|
||||
@ -1133,20 +1018,17 @@ class AlignTextModel(AlignPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1180,6 +1062,7 @@ class AlignVisionModel(AlignPreTrainedModel):
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.vision_model.embeddings.convolution
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1219,7 +1102,7 @@ class AlignVisionModel(AlignPreTrainedModel):
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
# Apply pooling
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
@ -1227,9 +1110,6 @@ class AlignVisionModel(AlignPreTrainedModel):
|
||||
# Reshape (batch_size, projection_dim, 1 , 1) -> (batch_size, projection_dim)
|
||||
pooled_output = pooled_output.reshape(pooled_output.shape[:2])
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndNoAttention(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
@ -1369,6 +1249,7 @@ class AlignModel(AlignPreTrainedModel):
|
||||
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1419,7 +1300,7 @@ class AlignModel(AlignPreTrainedModel):
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
@ -1431,7 +1312,7 @@ class AlignModel(AlignPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1]
|
||||
@ -1450,10 +1331,6 @@ class AlignModel(AlignPreTrainedModel):
|
||||
if return_loss:
|
||||
loss = align_loss(logits_per_text)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return AlignOutput(
|
||||
loss=loss,
|
||||
logits_per_image=logits_per_image,
|
||||
|
@ -26,14 +26,14 @@ from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndProjection,
|
||||
)
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, auto_docstring, logging, torch_int
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig
|
||||
|
||||
|
||||
@ -180,7 +180,6 @@ class AltRobertaEmbeddings(nn.Module):
|
||||
return position_ids.unsqueeze(0).expand(input_shape)
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->AltRoberta
|
||||
class AltRobertaSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
@ -206,13 +205,9 @@ class AltRobertaSelfAttention(nn.Module):
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -223,54 +218,18 @@ class AltRobertaSelfAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
@ -310,8 +269,6 @@ class AltRobertaSelfAttention(nn.Module):
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -335,7 +292,6 @@ ALT_ROBERTA_SELF_ATTENTION_CLASSES = {
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA
|
||||
class AltRobertaAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
@ -363,6 +319,9 @@ class AltRobertaAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -375,12 +334,9 @@ class AltRobertaAttention(nn.Module):
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -418,22 +374,19 @@ class AltRobertaOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->AltRoberta
|
||||
class AltRobertaLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = AltRobertaAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = AltRobertaAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = AltRobertaIntermediate(config)
|
||||
self.output = AltRobertaOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -443,60 +396,23 @@ class AltRobertaLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -505,14 +421,19 @@ class AltRobertaLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->AltRoberta
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->AltRoberta
|
||||
class AltRobertaEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([AltRobertaLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([AltRobertaLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -525,65 +446,36 @@ class AltRobertaEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -787,6 +679,7 @@ class AltCLIPEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -853,8 +746,6 @@ class AltCLIPEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
@ -1008,6 +899,7 @@ class AltCLIPVisionTransformer(nn.Module):
|
||||
self.encoder = AltCLIPEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1033,16 +925,13 @@ class AltCLIPVisionTransformer(nn.Module):
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
@ -1106,16 +995,11 @@ class AltCLIPVisionModel(AltCLIPPreTrainedModel):
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in *Attention is
|
||||
The model behaves as an encoder following the architecture described in *Attention is
|
||||
all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
|
||||
Kaiser and Illia Polosukhin.
|
||||
|
||||
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
||||
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
||||
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
||||
|
||||
.. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
|
||||
.. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
|
||||
"""
|
||||
)
|
||||
class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
@ -1152,6 +1036,10 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
|
||||
def forward(
|
||||
@ -1176,11 +1064,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -1194,11 +1077,8 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
@ -1212,21 +1092,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
@ -1235,33 +1100,23 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1284,6 +1139,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
|
||||
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
|
||||
return super().resize_token_embeddings(new_num_tokens)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1326,11 +1184,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# last module outputs
|
||||
@ -1343,9 +1199,6 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
|
||||
projection_state = self.transformation(sequence_output)
|
||||
pooler_output = projection_state[:, 0]
|
||||
|
||||
if not return_dict:
|
||||
return (projection_state, pooler_output) + outputs[2:4]
|
||||
|
||||
return BaseModelOutputWithPoolingAndProjection(
|
||||
last_hidden_state=projection_state,
|
||||
pooler_output=pooler_output,
|
||||
|
@ -1026,7 +1026,6 @@ class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
|
@ -26,13 +26,14 @@ from torch.nn import CrossEntropyLoss
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, auto_docstring, logging
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_bros import BrosConfig
|
||||
|
||||
|
||||
@ -150,7 +151,6 @@ class BrosTextEmbeddings(nn.Module):
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
past_key_values_length: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
@ -160,7 +160,7 @@ class BrosTextEmbeddings(nn.Module):
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self, "token_type_ids"):
|
||||
@ -208,14 +208,7 @@ class BrosSelfAttention(nn.Module):
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor):
|
||||
new_x_shape = x.size()[:-1] + (
|
||||
self.num_attention_heads,
|
||||
self.attention_head_size,
|
||||
)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -227,42 +220,21 @@ class BrosSelfAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[torch.Tensor] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
if is_cross_attention:
|
||||
key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
@ -317,7 +289,7 @@ class BrosSelfAttention(nn.Module):
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -364,6 +336,7 @@ class BrosAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -382,7 +355,6 @@ class BrosAttention(nn.Module):
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
@ -435,6 +407,7 @@ class BrosLayer(GradientCheckpointingLayer):
|
||||
self.intermediate = BrosIntermediate(config)
|
||||
self.output = BrosOutput(config)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -446,50 +419,38 @@ class BrosLayer(GradientCheckpointingLayer):
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
bbox_pos_emb=bbox_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if hasattr(self, "crossattention"):
|
||||
raise Exception(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
@ -500,7 +461,7 @@ class BrosLayer(GradientCheckpointingLayer):
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -516,6 +477,9 @@ class BrosEncoder(nn.Module):
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -529,33 +493,28 @@ class BrosEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
bbox_pos_emb,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
hidden_states=hidden_states,
|
||||
bbox_pos_emb=bbox_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
@ -564,21 +523,8 @@ class BrosEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
@ -689,6 +635,9 @@ class BrosModel(BrosPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -736,11 +685,6 @@ class BrosModel(BrosPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -756,9 +700,6 @@ class BrosModel(BrosPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
|
||||
@ -797,7 +738,6 @@ class BrosModel(BrosPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
|
||||
@ -813,22 +753,16 @@ class BrosModel(BrosPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
@ -852,6 +786,7 @@ class BrosForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -908,7 +843,7 @@ class BrosForTokenClassification(BrosPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -927,10 +862,6 @@ class BrosForTokenClassification(BrosPreTrainedModel):
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -976,6 +907,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1037,7 +969,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_states = outputs[0]
|
||||
@ -1082,10 +1014,6 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
loss = initial_token_loss + subsequent_token_loss
|
||||
|
||||
if not return_dict:
|
||||
output = (initial_token_logits, subsequent_token_logits) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return BrosSpadeOutput(
|
||||
loss=loss,
|
||||
initial_token_logits=initial_token_logits,
|
||||
@ -1118,6 +1046,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1173,7 +1102,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_states = outputs[0]
|
||||
@ -1203,10 +1132,6 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask])
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
|
@ -14,9 +14,8 @@
# limitations under the License.
"""PyTorch Chinese-CLIP model."""

import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.utils.checkpoint
@ -26,13 +25,13 @@ from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig

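Most of the behavioural change in this file is carried by the two helpers imported above. For orientation, the snippet below sketches the `deprecate_kwarg` pattern; it is an illustrative stand-in, not the actual implementation in `transformers.utils.deprecation`.

# Illustrative stand-in for the deprecate_kwarg decorator used throughout this diff:
# passing the deprecated keyword still works, but emits a FutureWarning naming the
# removal version; the wrapped forward is otherwise untouched.
import functools
import warnings


def deprecate_kwarg_sketch(old_name: str, version: str):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if kwargs.get(old_name) is not None:
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in version {version}.",
                    FutureWarning,
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecate_kwarg_sketch("past_key_value", version="4.54.0")
def forward(hidden_states, past_key_value=None):
    return hidden_states


forward([1.0, 2.0])                            # no warning
forward([1.0, 2.0], past_key_value=object())   # FutureWarning mentioning 4.54.0
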
@ -90,7 +89,7 @@ class ChineseCLIPOutput(ModelOutput):
)


# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->ChineseCLIPText
# Copied from transformers.models.align.modeling_align.AlignTextEmbeddings with Align->ChineseCLIP
class ChineseCLIPTextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

@ -119,7 +118,6 @@ class ChineseCLIPTextEmbeddings(nn.Module):
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
@ -129,7 +127,7 @@ class ChineseCLIPTextEmbeddings(nn.Module):
seq_length = input_shape[1]

if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, :seq_length]

# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@ -239,9 +237,37 @@ class ChineseCLIPVisionEmbeddings(nn.Module):
return embeddings


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ChineseCLIPText
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights

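As a quick illustration of the interface this diff standardizes on: `eager_attention_forward` takes `(batch, num_heads, seq_len, head_dim)` tensors plus an additive mask, and is only the fallback path, since the call is dispatched through `ALL_ATTENTION_FUNCTIONS` when `config._attn_implementation` is not `"eager"` (see the attention classes further down). The toy harness below is a sketch, not part of the diff; the dummy module, shapes, and the import path (valid only on a build that already contains this commit) are assumptions.

# Minimal sketch: exercising the eager attention path defined above with toy tensors.
# Assumes a transformers install that includes this commit, so the function is importable.
import torch
from torch import nn

from transformers.models.chinese_clip.modeling_chinese_clip import eager_attention_forward


class ToyModule(nn.Module):
    # eager_attention_forward only reads `module.training` (for dropout), so a bare nn.Module is enough.
    pass


batch, heads, seq, head_dim = 2, 4, 7, 16
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)
# Additive mask in the usual extended format: 0.0 keeps a position, a large negative value masks it.
attention_mask = torch.zeros(batch, 1, 1, seq)

attn_output, attn_weights = eager_attention_forward(
    ToyModule().eval(), query, key, value, attention_mask, scaling=head_dim**-0.5
)
print(attn_output.shape)   # torch.Size([2, 7, 4, 16]) -- already transposed back to (batch, seq, heads, head_dim)
print(attn_weights.shape)  # torch.Size([2, 4, 7, 7])
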
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->ChineseCLIP
|
||||
class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -249,6 +275,7 @@ class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -258,20 +285,12 @@ class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -281,96 +300,33 @@ class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in ChineseCLIPTextModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -389,18 +345,11 @@ class ChineseCLIPTextSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = {
|
||||
"eager": ChineseCLIPTextSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->ChineseCLIP
|
||||
class ChineseCLIPTextAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = ChineseCLIPTextSelfAttention(config)
|
||||
self.output = ChineseCLIPTextSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -422,6 +371,9 @@ class ChineseCLIPTextAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -431,15 +383,14 @@ class ChineseCLIPTextAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -468,66 +419,37 @@ class ChineseCLIPVisionAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, **kwargs
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.scale
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
None,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=1.0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit akward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText
|
||||
@ -577,22 +499,19 @@ class ChineseCLIPVisionMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->ChineseCLIP
|
||||
class ChineseCLIPTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = ChineseCLIPTextAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = ChineseCLIPTextAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = ChineseCLIPTextIntermediate(config)
|
||||
self.output = ChineseCLIPTextOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -602,60 +521,23 @@ class ChineseCLIPTextLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -777,14 +659,19 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ChineseCLIPText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP
|
||||
class ChineseCLIPTextEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -797,65 +684,36 @@ class ChineseCLIPTextEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -874,6 +732,7 @@ class ChineseCLIPVisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -922,8 +781,6 @@ class ChineseCLIPVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
@ -940,6 +797,7 @@ class ChineseCLIPVisionTransformer(nn.Module):
|
||||
self.encoder = ChineseCLIPVisionEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -965,16 +823,13 @@ class ChineseCLIPVisionTransformer(nn.Module):
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
@ -1034,6 +889,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1050,18 +906,13 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -1093,56 +944,28 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1343,6 +1166,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1392,7 +1216,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
@ -1402,7 +1226,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1]
|
||||
@ -1424,14 +1248,6 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
if return_loss:
|
||||
loss = chinese_clip_loss(logits_per_text)
|
||||
|
||||
if not return_dict:
|
||||
# fix the None pooled_output of text_outputs to conform with dict_output
|
||||
pooled_output = text_outputs[1]
|
||||
if pooled_output is None:
|
||||
text_outputs = (text_outputs[0],) + text_outputs[2:]
|
||||
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return ChineseCLIPOutput(
|
||||
loss=loss,
|
||||
logits_per_image=logits_per_image,
|
||||
|
@ -17,7 +17,7 @@
import collections
import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.nn.functional as F
@ -26,13 +26,14 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutput,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig

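The same `return_dict=True` plus `@can_return_tuple` pattern recurs throughout this file: the decorated forward always builds a ModelOutput-style object and calls its sub-modules with `return_dict=True`, and the decorator converts the result back to a tuple only when the caller explicitly asked for one. A hedged sketch of that contract follows; the tiny classes and the simplified decorator are assumptions for illustration, not the real `transformers.utils` implementation.

# Sketch of the can_return_tuple contract (simplified; the real decorator also handles
# config defaults and other bookkeeping).
import functools
from dataclasses import dataclass


@dataclass
class TinyOutput:
    last_hidden_state: float
    pooler_output: float

    def to_tuple(self):
        return (self.last_hidden_state, self.pooler_output)


def can_return_tuple_sketch(func):
    @functools.wraps(func)
    def wrapper(self, *args, return_dict=None, **kwargs):
        output = func(self, *args, **kwargs)  # forward itself no longer branches on return_dict
        return output.to_tuple() if return_dict is False else output

    return wrapper


class TinyEncoder:
    @can_return_tuple_sketch
    def forward(self, x):
        return TinyOutput(last_hidden_state=x, pooler_output=x + 1.0)


encoder = TinyEncoder()
print(encoder.forward(3.0))                     # TinyOutput(last_hidden_state=3.0, pooler_output=4.0)
print(encoder.forward(3.0, return_dict=False))  # (3.0, 4.0)
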
@ -399,11 +400,6 @@ class ClapAudioSelfAttention(nn.Module):

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
self,
hidden_states: torch.Tensor,
@ -412,11 +408,11 @@ class ClapAudioSelfAttention(nn.Module):
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
hidden_shape = (batch_size, dim, -1, self.attention_head_size)

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
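The `.view(hidden_shape).transpose(1, 2)` form introduced above is numerically identical to the removed `transpose_for_scores` helper; a quick check with arbitrary toy sizes (the shapes below are assumptions for illustration):

# Equivalence check (illustrative): inline view + transpose(1, 2) reproduces the old
# reshape + permute(0, 2, 1, 3) exactly.
import torch

batch_size, seq_len, num_heads, head_size = 2, 5, 4, 8
x = torch.randn(batch_size, seq_len, num_heads * head_size)

old = x.view(batch_size, seq_len, num_heads, head_size).permute(0, 2, 1, 3)  # removed helper
new = x.view(batch_size, seq_len, -1, head_size).transpose(1, 2)             # new inline form

assert torch.equal(old, new)
print(new.shape)  # torch.Size([2, 4, 5, 8]) -> (batch, num_heads, seq_len, head_dim)
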
@ -1090,9 +1086,37 @@ class ClapTextEmbeddings(nn.Module):
|
||||
return position_ids.unsqueeze(0).expand(input_shape)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ClapText
|
||||
# Copied from transformers.models.align.modeling_align.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->Clap
|
||||
class ClapTextSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -1100,6 +1124,7 @@ class ClapTextSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -1109,20 +1134,12 @@ class ClapTextSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1132,96 +1149,33 @@ class ClapTextSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

query_layer = self.transpose_for_scores(mixed_query_layer)

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
head_mask=head_mask,
**kwargs,
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r

positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in ClapTextModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs

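For readers following the hunk above: the cache-handling branches are removed and the eager math is routed through a pluggable attention callable. As a minimal standalone sketch of that dispatch pattern (illustrative only; the registry dict and toy module below are stand-ins for ALL_ATTENTION_FUNCTIONS and the real self-attention class):

import torch
from torch import nn

def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # query/key/value: (batch, heads, seq, head_dim)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

# hypothetical registry standing in for ALL_ATTENTION_FUNCTIONS
ATTENTION_REGISTRY = {"eager": eager_attention_forward}

class ToySelfAttention(nn.Module):
    def __init__(self, hidden_size=64, num_heads=4, attn_implementation="eager"):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim**-0.5
        self.q = nn.Linear(hidden_size, hidden_size)
        self.k = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, hidden_size)
        self.attn_implementation = attn_implementation

    def forward(self, hidden_states, attention_mask=None):
        bsz, seq, _ = hidden_states.shape
        shape = (bsz, seq, self.num_heads, self.head_dim)
        query = self.q(hidden_states).view(shape).transpose(1, 2)
        key = self.k(hidden_states).view(shape).transpose(1, 2)
        value = self.v(hidden_states).view(shape).transpose(1, 2)
        # pick the attention backend by name, falling back to the eager path
        attention_interface = ATTENTION_REGISTRY.get(self.attn_implementation, eager_attention_forward)
        attn_output, _ = attention_interface(self, query, key, value, attention_mask, scaling=self.scaling)
        return attn_output.reshape(bsz, seq, -1)

x = torch.randn(2, 10, 64)
print(ToySelfAttention()(x).shape)  # torch.Size([2, 10, 64])
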
@ -1240,18 +1194,11 @@ class ClapTextSelfOutput(nn.Module):
return hidden_states


CLAP_TEXT_SELF_ATTENTION_CLASSES = {
"eager": ClapTextSelfAttention,
}


# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT
# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->Clap
class ClapTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.self = ClapTextSelfAttention(config)
self.output = ClapTextSelfOutput(config)
self.pruned_heads = set()

@ -1273,6 +1220,9 @@ class ClapTextAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)

@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -1282,15 +1232,14 @@ class ClapTextAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
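The @deprecate_kwarg decorators added above come from transformers.utils.deprecation. As a rough illustration of the pattern only (a simplified stand-in, not the library's actual implementation), a decorator of this shape warns when a soon-to-be-removed keyword is still passed and otherwise leaves the call untouched:

import functools
import warnings

def deprecate_kwarg_sketch(name, version):
    """Warn if `name` is passed as a keyword argument; simplified stand-in for the real helper."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if kwargs.get(name) is not None:
                warnings.warn(
                    f"`{name}` is deprecated and will be removed in version {version}.",
                    FutureWarning,
                )
            return func(*args, **kwargs)
        return wrapper
    return decorator

class Example:
    @deprecate_kwarg_sketch("past_key_value", version="4.54.0")
    def forward(self, hidden_states, past_key_value=None):
        return hidden_states

Example().forward("x", past_key_value=("k", "v"))  # emits a FutureWarning
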
@ -1328,22 +1277,19 @@ class ClapTextOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->Clap
|
||||
class ClapTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = ClapTextAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = ClapTextAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = ClapTextIntermediate(config)
|
||||
self.output = ClapTextOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1353,60 +1299,23 @@ class ClapTextLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -1415,14 +1324,19 @@ class ClapTextLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ClapText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->Clap
|
||||
class ClapTextEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([ClapTextLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([ClapTextLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1435,65 +1349,36 @@ class ClapTextEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
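Stripped of cache bookkeeping, the encoder refactor above reduces to stacking layers and optionally collecting hidden states, then wrapping the result in a dataclass-style output. A minimal standalone sketch of that shape (toy layer and toy output class, not the actual ClapTextEncoder or BaseModelOutput):

import torch
from torch import nn
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyBaseModelOutput:
    last_hidden_state: torch.Tensor
    hidden_states: Optional[tuple] = None

class ToyEncoder(nn.Module):
    def __init__(self, hidden_size=32, num_layers=3):
        super().__init__()
        self.layer = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))

    def forward(self, hidden_states, output_hidden_states=False):
        all_hidden_states = () if output_hidden_states else None
        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            hidden_states = layer_module(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        return ToyBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)

out = ToyEncoder()(torch.randn(2, 5, 32), output_hidden_states=True)
print(out.last_hidden_state.shape, len(out.hidden_states))  # torch.Size([2, 5, 32]) 4
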
@ -1643,6 +1528,11 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1666,11 +1556,6 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -1684,11 +1569,8 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
@ -1702,21 +1584,6 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
@ -1725,33 +1592,23 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1892,6 +1749,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
|
||||
return audio_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1947,7 +1805,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
is_longer=is_longer,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
@ -1956,7 +1814,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
audio_embeds = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
|
||||
@ -1981,10 +1839,6 @@ class ClapModel(ClapPreTrainedModel):
|
||||
audio_loss = contrastive_loss(logits_per_audio.t())
|
||||
loss = (caption_loss + audio_loss) / 2.0
|
||||
|
||||
if not return_dict:
|
||||
output = (logits_per_audio, logits_per_text, text_embeds, audio_embeds, text_outputs, audio_outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return ClapOutput(
|
||||
loss=loss,
|
||||
logits_per_audio=logits_per_audio,
|
||||
@ -2013,6 +1867,7 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.text_model.embeddings.word_embeddings = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -2045,17 +1900,13 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output
|
||||
|
||||
text_embeds = self.text_projection(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
|
||||
return tuple(output for output in outputs if output is not None)
|
||||
|
||||
return ClapTextModelOutput(
|
||||
text_embeds=text_embeds,
|
||||
last_hidden_state=text_outputs.last_hidden_state,
|
||||
@ -2079,6 +1930,7 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.audio_model.audio_encoder.patch_embed.proj
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -2123,17 +1975,13 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
|
||||
is_longer=is_longer,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
|
||||
|
||||
audio_embeds = self.audio_projection(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
outputs = (audio_embeds, audio_outputs[0]) + audio_outputs[2:]
|
||||
return tuple(output for output in outputs if output is not None)
|
||||
|
||||
return ClapAudioModelOutput(
|
||||
audio_embeds=audio_embeds,
|
||||
last_hidden_state=audio_outputs.last_hidden_state,
|
||||
|
@ -28,7 +28,7 @@ from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepa
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring, logging, torch_int
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
|
||||
from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig
|
||||
|
||||
|
||||
@ -490,6 +490,7 @@ class CLIPSegEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -555,8 +556,6 @@ class CLIPSegEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
|
@ -45,6 +45,7 @@ from ...modeling_outputs import (
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_data2vec_audio import Data2VecAudioConfig
|
||||
|
||||
|
||||
@ -240,6 +241,7 @@ class Data2VecAudioAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -247,7 +249,7 @@ class Data2VecAudioAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -268,42 +270,9 @@ class Data2VecAudioAttention(nn.Module):
# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
current_states = key_value_states if is_cross_attention else hidden_states
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@ -325,7 +294,7 @@ class Data2VecAudioAttention(nn.Module):
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights, past_key_value
return attn_output, attn_weights, None

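The collapsed branches above boil down to one rule: with no cache, keys and values are always freshly projected, either from the encoder states (cross-attention) or from the current hidden states (self-attention). A standalone sketch of that selection (illustrative shapes and helper name, not the exact Data2VecAudio code):

import torch
from torch import nn

def project_key_value(k_proj, v_proj, hidden_states, key_value_states=None, num_heads=4):
    """Project keys/values from encoder states when doing cross-attention, else from hidden_states."""
    is_cross_attention = key_value_states is not None
    current_states = key_value_states if is_cross_attention else hidden_states
    bsz, src_len, embed_dim = current_states.shape
    kv_shape = (bsz, src_len, num_heads, embed_dim // num_heads)
    key_states = k_proj(current_states).view(*kv_shape).transpose(1, 2)
    value_states = v_proj(current_states).view(*kv_shape).transpose(1, 2)
    return key_states, value_states

embed_dim = 16
k_proj, v_proj = nn.Linear(embed_dim, embed_dim), nn.Linear(embed_dim, embed_dim)
hidden = torch.randn(2, 7, embed_dim)
encoder = torch.randn(2, 11, embed_dim)
print(project_key_value(k_proj, v_proj, hidden)[0].shape)           # torch.Size([2, 4, 7, 4])
print(project_key_value(k_proj, v_proj, hidden, encoder)[0].shape)  # torch.Size([2, 4, 11, 4])
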
class Data2VecAudioFeedForward(nn.Module):
|
||||
|
@ -634,7 +634,6 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
|
@ -405,11 +405,6 @@ class DonutSwinSelfAttention(nn.Module):

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
self,
hidden_states: torch.Tensor,
@ -418,11 +413,11 @@ class DonutSwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
hidden_shape = (batch_size, dim, -1, self.attention_head_size)

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

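The DonutSwin change above replaces the transpose_for_scores helper with an inline view plus transpose; the two produce the same layout. A quick equivalence check under toy sizes, assuming the usual (batch, heads, seq, head_dim) split:

import torch

batch, seq, num_heads, head_dim = 2, 9, 3, 8
x = torch.randn(batch, seq, num_heads * head_dim)

# old helper: view to (batch, seq, heads, head_dim) then permute to (batch, heads, seq, head_dim)
old = x.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)

# new inline form: view with -1 for the head axis, then transpose(1, 2)
new = x.view(batch, seq, -1, head_dim).transpose(1, 2)

print(torch.equal(old, new))  # True
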
@ -26,14 +26,15 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
MaskedLMOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_esm import EsmConfig
|
||||
|
||||
|
||||
@ -187,12 +188,16 @@ class EsmEmbeddings(nn.Module):
|
||||
self.mask_token_id = config.mask_token_id
|
||||
|
||||
def forward(
|
||||
self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
if position_ids is None:
|
||||
if input_ids is not None:
|
||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
|
||||
else:
|
||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||
|
||||
@ -281,11 +286,7 @@ class EsmSelfAttention(nn.Module):
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -296,32 +297,22 @@ class EsmSelfAttention(nn.Module):
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)

query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
if is_cross_attention:
key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

# Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
# ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
@ -329,16 +320,6 @@ class EsmSelfAttention(nn.Module):
# ESM code and fix rotary embeddings.
query_layer = query_layer * self.attention_head_size**-0.5

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

@ -385,7 +366,7 @@ class EsmSelfAttention(nn.Module):
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
outputs = outputs + (None,)
return outputs

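The comment kept above notes that ESM scales the query by 1/sqrt(head_dim) instead of scaling the attention logits; modulo floating-point error the two are the same. A small numerical check of that claim:

import math
import torch

head_dim = 16
q = torch.randn(2, 4, 10, head_dim)
k = torch.randn(2, 4, 10, head_dim)

logits_scaled_after = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)   # BERT-style
logits_scaled_query = torch.matmul(q * head_dim**-0.5, k.transpose(-1, -2))        # ESM-style

print(torch.allclose(logits_scaled_after, logits_scaled_query, atol=1e-6))  # True
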
@ -418,6 +399,7 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -441,7 +423,6 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
@ -450,9 +431,6 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
if past_key_value is not None:
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
@ -514,7 +492,7 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
|
||||
outputs = (attn_output, None)
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -551,6 +529,7 @@ class EsmAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -564,12 +543,11 @@ class EsmAttention(nn.Module):
|
||||
hidden_states_ln = self.LayerNorm(hidden_states)
|
||||
self_outputs = self.self(
|
||||
hidden_states_ln,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -616,6 +594,7 @@ class EsmLayer(GradientCheckpointingLayer):
|
||||
self.output = EsmOutput(config)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -626,25 +605,20 @@ class EsmLayer(GradientCheckpointingLayer):
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise AttributeError(
|
||||
@ -652,31 +626,24 @@ class EsmLayer(GradientCheckpointingLayer):
|
||||
" with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = self.feed_forward_chunk(attention_output)
|
||||
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -694,6 +661,9 @@ class EsmEncoder(nn.Module):
|
||||
self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -707,38 +677,26 @@ class EsmEncoder(nn.Module):
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
|
||||
"`use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
@ -750,21 +708,8 @@ class EsmEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
@ -863,6 +808,9 @@ class EsmModel(EsmPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -903,11 +851,6 @@ class EsmModel(EsmPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -921,11 +864,8 @@ class EsmModel(EsmPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
extended_attention_mask = attention_mask
|
||||
@ -958,7 +898,6 @@ class EsmModel(EsmPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@ -966,22 +905,16 @@ class EsmModel(EsmPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
@ -1025,6 +958,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1058,7 +992,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output)
|
||||
@ -1070,10 +1004,6 @@ class EsmForMaskedLM(EsmPreTrainedModel):
|
||||
labels = labels.to(prediction_scores.device)
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=prediction_scores,
|
||||
@ -1125,6 +1055,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
|
||||
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1154,7 +1085,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.classifier(sequence_output)
|
||||
@ -1184,10 +1115,6 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1210,6 +1137,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
|
||||
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1237,7 +1165,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1252,10 +1180,6 @@ class EsmForTokenClassification(EsmPreTrainedModel):
|
||||
labels = labels.to(logits.device)
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1283,7 +1207,7 @@ class EsmClassificationHead(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
@ -1295,7 +1219,7 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
return incremental_indices.long() + padding_idx

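A quick usage sketch of the simplified helper above (padding_idx=1 is assumed here, as in RoBERTa-style vocabularies): padded positions stay at padding_idx and real tokens count up from padding_idx + 1.

import torch

def create_position_ids_from_input_ids(input_ids, padding_idx):
    # mirrors the simplified helper above: cumulative count over non-padding tokens
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx

input_ids = torch.tensor([[5, 6, 7, 1, 1],
                          [8, 9, 1, 1, 1]])
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1],
#         [2, 3, 1, 1, 1]])
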
@ -39,6 +39,7 @@ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
torch_int,
|
||||
)
|
||||
@ -770,6 +771,7 @@ class GitVisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -836,8 +838,6 @@ class GitVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
|
@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_hubert import HubertConfig
|
||||
|
||||
|
||||
@ -300,6 +301,7 @@ class HubertAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -307,7 +309,7 @@ class HubertAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -328,42 +330,9 @@ class HubertAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -385,7 +354,7 @@ class HubertAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class HubertFeedForward(nn.Module):
|
||||
|
@ -28,6 +28,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
)
|
||||
from .configuration_idefics import IdeficsVisionConfig
|
||||
@ -351,6 +352,7 @@ class IdeficsVisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -417,8 +419,6 @@ class IdeficsVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
|
@ -451,6 +451,7 @@ class Kosmos2VisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -517,8 +518,6 @@ class Kosmos2VisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
|
@ -14,6 +14,7 @@
# limitations under the License.
"""LayoutLM model configuration"""

import warnings
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any, Optional
@ -130,10 +131,22 @@ class LayoutLMConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self._position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.max_2d_position_embeddings = max_2d_position_embeddings

@property
def position_embedding_type(self):
warnings.warn(
"The `position_embedding_type` attribute is deprecated and will be removed in v4.55.",
FutureWarning,
)
return self._position_embedding_type

@position_embedding_type.setter
def position_embedding_type(self, value):
self._position_embedding_type = value

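The config change above keeps position_embedding_type readable while warning on access, with the value itself stored on a private attribute. A toy version of the same property-based deprecation pattern (not the LayoutLM class itself):

import warnings

class ToyConfig:
    def __init__(self, position_embedding_type="absolute"):
        self._position_embedding_type = position_embedding_type

    @property
    def position_embedding_type(self):
        warnings.warn(
            "The `position_embedding_type` attribute is deprecated and will be removed in a future version.",
            FutureWarning,
        )
        return self._position_embedding_type

    @position_embedding_type.setter
    def position_embedding_type(self, value):
        self._position_embedding_type = value

cfg = ToyConfig()
cfg.position_embedding_type = "relative_key"   # setter: no warning
print(cfg.position_embedding_type)             # getter: warns, then prints "relative_key"
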
class LayoutLMOnnxConfig(OnnxConfig):
|
||||
def __init__(
|
||||
|
@ -14,8 +14,7 @@
# limitations under the License.
"""PyTorch LayoutLM model."""

import math
from typing import Optional, Union
from typing import Callable, Optional, Union

import torch
import torch.utils.checkpoint
@ -25,16 +24,17 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging
from ...utils import auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_layoutlm import LayoutLMConfig

@ -120,9 +120,37 @@ class LayoutLMEmbeddings(nn.Module):
return embeddings


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights

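Note: the new `eager_attention_forward` helper takes `(batch, num_heads, seq_len, head_dim)` projections and returns the context already transposed to `(batch, seq_len, num_heads, head_dim)`. A minimal sketch of calling it, assuming the function defined in this hunk is in scope:

import torch
from torch import nn

batch, heads, seq, head_dim = 2, 4, 16, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

dummy = nn.Module()  # only `.training` is read by the helper, to gate dropout
attn_output, attn_weights = eager_attention_forward(
    dummy, q, k, v, attention_mask=None, scaling=head_dim**-0.5, dropout=0.0
)
print(attn_output.shape)   # torch.Size([2, 16, 4, 8]) after the final transpose
print(attn_weights.shape)  # torch.Size([2, 4, 16, 16])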
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->LayoutLM
class LayoutLMSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@ -130,6 +158,7 @@ class LayoutLMSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})"
)

self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -139,20 +168,12 @@ class LayoutLMSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_decoder = config.is_decoder

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.attention_dropout = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5

@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -162,96 +183,33 @@ class LayoutLMSelfAttention(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

query_layer = self.transpose_for_scores(mixed_query_layer)

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
head_mask=head_mask,
**kwargs,
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r

positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in LayoutLMModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs

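Note: the rewritten forward picks its attention kernel from a registry keyed by `config._attn_implementation`, falling back to the eager helper. A toy sketch of that dispatch shape; `ATTENTION_REGISTRY`, `sdpa_like`, and `pick_attention` are stand-in names for illustration, not transformers' actual `ALL_ATTENTION_FUNCTIONS` entries:

from typing import Callable
import torch.nn.functional as F

def sdpa_like(module, query, key, value, attention_mask, dropout=0.0, **kwargs):
    # The real registry entries also accept scaling/head-mask style kwargs.
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=dropout)
    return out.transpose(1, 2).contiguous(), None  # attention weights are not materialized

ATTENTION_REGISTRY = {"sdpa": sdpa_like}  # stand-in for ALL_ATTENTION_FUNCTIONS

def pick_attention(attn_implementation: str, eager_fn: Callable) -> Callable:
    # Mirrors the `if self.config._attn_implementation != "eager"` branch above.
    return eager_fn if attn_implementation == "eager" else ATTENTION_REGISTRY[attn_implementation]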
@ -270,18 +228,11 @@ class LayoutLMSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
LAYOUTLM_SELF_ATTENTION_CLASSES = {
|
||||
"eager": LayoutLMSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->LayoutLM
|
||||
class LayoutLMAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = LayoutLMSelfAttention(config)
|
||||
self.output = LayoutLMSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -303,6 +254,9 @@ class LayoutLMAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -312,15 +266,14 @@ class LayoutLMAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -358,22 +311,19 @@ class LayoutLMOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->LayoutLM
|
||||
class LayoutLMLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = LayoutLMAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = LayoutLMIntermediate(config)
|
||||
self.output = LayoutLMOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -383,60 +333,23 @@ class LayoutLMLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -445,14 +358,19 @@ class LayoutLMLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->LayoutLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->LayoutLM
|
||||
class LayoutLMEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([LayoutLMLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -465,65 +383,36 @@ class LayoutLMEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -648,6 +537,9 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -663,7 +555,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
) -> Union[tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
|
||||
Bounding boxes of each input sequence tokens. Selected in the range `[0,
|
||||
@ -756,20 +648,16 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -796,6 +684,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
self.cls.predictions.decoder = new_embeddings
|
||||
self.cls.predictions.bias = new_embeddings.bias
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -871,11 +762,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -889,10 +778,6 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
labels.view(-1),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=prediction_scores,
|
||||
@ -921,6 +806,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.layoutlm.embeddings.word_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -996,7 +882,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@ -1026,9 +912,6 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
@ -1059,6 +942,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.layoutlm.embeddings.word_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1132,7 +1016,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1145,10 +1029,6 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1176,6 +1056,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.layoutlm.embeddings.word_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1253,7 +1134,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1280,10 +1161,6 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
start_logits=start_logits,
|
||||
|
@ -14,6 +14,8 @@
# limitations under the License.
"""MarkupLM model configuration"""

import warnings

from ...configuration_utils import PretrainedConfig
from ...utils import logging

@ -141,7 +143,7 @@ class MarkupLMConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self._position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
# additional properties
@ -152,5 +154,17 @@ class MarkupLMConfig(PretrainedConfig):
self.subs_pad_id = subs_pad_id
self.xpath_unit_hidden_size = xpath_unit_hidden_size

@property
def position_embedding_type(self):
warnings.warn(
"The `position_embedding_type` attribute is deprecated and will be removed in v4.55.",
FutureWarning,
)
return self._position_embedding_type

@position_embedding_type.setter
def position_embedding_type(self, value):
self._position_embedding_type = value


__all__ = ["MarkupLMConfig"]
@ -14,9 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch MarkupLM model."""
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -26,20 +25,22 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
MaskedLMOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import (
|
||||
ALL_ATTENTION_FUNCTIONS,
|
||||
PreTrainedModel,
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
prune_linear_layer,
|
||||
)
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_markuplm import MarkupLMConfig
|
||||
|
||||
|
||||
@ -326,9 +327,37 @@ class MarkupLMOnlyMLMHead(nn.Module):
|
||||
return prediction_scores
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MarkupLM
|
||||
# Copied from transformers.models.align.modeling_align.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->MarkupLM
|
||||
class MarkupLMSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -336,6 +365,7 @@ class MarkupLMSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -345,20 +375,12 @@ class MarkupLMSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -368,111 +390,41 @@ class MarkupLMSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in MarkupLMModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
MARKUPLM_SELF_ATTENTION_CLASSES = {
|
||||
"eager": MarkupLMSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM,BERT->MARKUPLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->MarkupLM
|
||||
class MarkupLMAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = MARKUPLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = MarkupLMSelfAttention(config)
|
||||
self.output = MarkupLMSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -494,6 +446,9 @@ class MarkupLMAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -503,37 +458,33 @@ class MarkupLMAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->MarkupLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->MarkupLM
|
||||
class MarkupLMLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = MarkupLMAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = MarkupLMAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = MarkupLMIntermediate(config)
|
||||
self.output = MarkupLMOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -543,60 +494,23 @@ class MarkupLMLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -605,14 +519,19 @@ class MarkupLMLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->MarkupLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->MarkupLM
|
||||
class MarkupLMEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([MarkupLMLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([MarkupLMLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -625,65 +544,36 @@ class MarkupLMEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -749,6 +639,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -763,7 +654,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
) -> Union[tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
|
||||
Tag IDs for each token in the input sequence, padded up to config.max_depth.
|
||||
@ -839,21 +730,16 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache
|
||||
@ -879,6 +765,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -939,7 +826,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -966,10 +853,6 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
start_logits=start_logits,
|
||||
@ -1000,6 +883,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1058,7 +942,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1072,10 +956,6 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
|
||||
labels.view(-1),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=prediction_scores,
|
||||
@ -1107,6 +987,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1164,7 +1045,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@ -1194,9 +1075,6 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
|
@ -354,11 +354,6 @@ class MaskFormerSwinSelfAttention(nn.Module):

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
self,
hidden_states: torch.Tensor,
@ -367,11 +362,11 @@ class MaskFormerSwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
hidden_shape = (batch_size, dim, -1, self.attention_head_size)

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
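Note: the MaskFormerSwin change above only inlines the head split; `.view(...).transpose(1, 2)` yields the same layout as the removed permute-based `transpose_for_scores`. A quick, self-contained check:

import torch

batch, seq, heads, head_dim = 2, 10, 4, 8
hidden = torch.randn(batch, seq, heads * head_dim)

old = hidden.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)  # old helper
new = hidden.view(batch, seq, heads, head_dim).transpose(1, 2)      # inlined form
assert torch.equal(old, new)  # both give (batch, heads, seq, head_dim)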
@ -182,7 +182,6 @@ def eager_attention_forward(
return attn_output, attn_weights


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Musicgen
class MusicgenAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""


@ -189,7 +189,7 @@ def eager_attention_forward(
return attn_output, attn_weights


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->MusicgenMelody
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->MusicgenMelody
class MusicgenMelodyAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""


@ -503,7 +503,7 @@ def eager_attention_forward(
return attn_output, attn_weights


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->NllbMoe,key_value_states->encoder_hidden_states
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->NllbMoe,key_value_states->encoder_hidden_states
class NllbMoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@ -29,6 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_patchtsmixer import PatchTSMixerConfig


@ -303,6 +304,7 @@ class PatchTSMixerAttention(nn.Module):
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@ -310,7 +312,7 @@ class PatchTSMixerAttention(nn.Module):
past_key_value: Optional[tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_attentions: Optional[bool] = False,
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
**kwargs: Unpack[FlashAttentionKwargs],
@ -331,42 +333,9 @@ class PatchTSMixerAttention(nn.Module):
# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
current_states = key_value_states if is_cross_attention else hidden_states
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -388,7 +357,7 @@ class PatchTSMixerAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class PatchMixerBlock(nn.Module):
|
||||
|
@ -28,6 +28,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
|
||||
from ...utils import ModelOutput, auto_docstring, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_patchtst import PatchTSTConfig
|
||||
|
||||
|
||||
@ -100,6 +101,7 @@ class PatchTSTAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -107,7 +109,7 @@ class PatchTSTAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -128,42 +130,9 @@ class PatchTSTAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -185,7 +154,7 @@ class PatchTSTAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class PatchTSTBatchNorm(nn.Module):
|
||||
|
@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_sew import SEWConfig
|
||||
|
||||
|
||||
@ -293,6 +294,7 @@ class SEWAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -300,7 +302,7 @@ class SEWAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -321,42 +323,9 @@ class SEWAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -378,7 +347,7 @@ class SEWAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class SEWFeedForward(nn.Module):
|
||||
|
@ -205,7 +205,7 @@ def eager_attention_forward(
return attn_output, attn_weights


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Speech2Text
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->Speech2Text
class Speech2TextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@ -14,9 +14,8 @@
# limitations under the License.
"""PyTorch Splinter model."""

import math
from dataclasses import dataclass
from typing import Optional, Union
from typing import Callable, Optional, Union

import torch
import torch.utils.checkpoint
@ -25,13 +24,19 @@ from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_outputs import (
BaseModelOutput,
ModelOutput,
QuestionAnsweringModelOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
auto_docstring,
can_return_tuple,
logging,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_splinter import SplinterConfig


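Splinter now imports `deprecate_kwarg` and applies it to the cache-related arguments below. As a rough mental model only, not the transformers implementation, such a decorator warns when a caller still passes the named keyword and otherwise leaves the call untouched; the helper name in this sketch is invented:

import functools
import warnings

def deprecate_kwarg_sketch(name, version):
    # Hypothetical stand-in for `@deprecate_kwarg(name, version=...)`; the real helper
    # lives in transformers.utils.deprecation and may differ in detail.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if kwargs.get(name) is not None:
                warnings.warn(
                    f"`{name}` is deprecated and will be removed in version {version}.",
                    FutureWarning,
                )
            return func(*args, **kwargs)  # still forwarded during the deprecation window
        return wrapper
    return decorator

Callers that still pass `past_key_value`, `encoder_hidden_states`, or `encoder_attention_mask` therefore keep working for now but are warned that the argument goes away in the version named in the decorator (4.54.0 in the hunks below).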
@ -64,7 +69,6 @@ class SplinterEmbeddings(nn.Module):
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: Optional[int] = 0,
) -> tuple:
if input_ids is not None:
input_shape = input_ids.size()
@ -74,7 +78,7 @@ class SplinterEmbeddings(nn.Module):
seq_length = input_shape[1]

if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, :seq_length]

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
@ -92,9 +96,37 @@ class SplinterEmbeddings(nn.Module):
return embeddings


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights


# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->Splinter
class SplinterSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@ -102,6 +134,7 @@ class SplinterSelfAttention(nn.Module):
f"heads ({config.num_attention_heads})"
)

self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -111,20 +144,12 @@ class SplinterSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_decoder = config.is_decoder

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.attention_dropout = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5

@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,

|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in SplinterModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
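The rewritten `SplinterSelfAttention.forward` above chooses an attention backend at call time: `eager_attention_forward` by default, otherwise the callable registered under `self.config._attn_implementation` in `ALL_ATTENTION_FUNCTIONS`. A self-contained sketch of that dispatch pattern; the registry dict, function names, and tensor sizes here are illustrative, not the library's:

import torch
from torch import nn

# stand-in registry; in transformers the real one is ALL_ATTENTION_FUNCTIONS
ATTENTION_BACKENDS = {}

def eager_backend(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # same recipe as the eager_attention_forward added in this diff
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

def run_attention(module, impl, query, key, value, attention_mask=None, scaling=1.0):
    # eager is the fallback; any other implementation must be registered
    backend = eager_backend if impl == "eager" else ATTENTION_BACKENDS[impl]
    return backend(module, query, key, value, attention_mask, scaling=scaling, dropout=0.0)

dummy = nn.Module()
q = k = v = torch.randn(1, 2, 3, 4)  # (batch, heads, seq, head_dim)
out, weights = run_attention(dummy, "eager", q, k, v, scaling=4 ** -0.5)

Because no cache is maintained, the forward pass only returns the attention output (and weights when requested), which is why the old `past_key_value` plumbing could be deleted wholesale.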
@ -242,18 +204,11 @@ class SplinterSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
SPLINTER_SELF_ATTENTION_CLASSES = {
|
||||
"eager": SplinterSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->Splinter
|
||||
class SplinterAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = SplinterSelfAttention(config)
|
||||
self.output = SplinterSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -275,6 +230,9 @@ class SplinterAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -284,15 +242,14 @@ class SplinterAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -330,22 +287,19 @@ class SplinterOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->Splinter
|
||||
class SplinterLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = SplinterAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = SplinterAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = SplinterIntermediate(config)
|
||||
self.output = SplinterOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -355,60 +309,23 @@ class SplinterLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -417,14 +334,19 @@ class SplinterLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Splinter
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->Splinter
|
||||
class SplinterEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([SplinterLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -437,65 +359,36 @@ class SplinterEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
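With caching removed, `SplinterEncoder.forward` above reduces to a plain layer loop that optionally accumulates hidden states and attention maps and returns a `BaseModelOutput`. A stripped-down sketch of that control flow; the `TinyEncoder`/`TinyEncoderOutput` names and the linear layers are invented stand-ins for illustration only:

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

@dataclass
class TinyEncoderOutput:
    # minimal stand-in for BaseModelOutput
    last_hidden_state: torch.Tensor
    hidden_states: Optional[tuple] = None

class TinyEncoder(nn.Module):
    def __init__(self, hidden_size=8, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))

    def forward(self, hidden_states, output_hidden_states=False):
        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            hidden_states = layer(hidden_states)  # real layers also return attention weights
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        return TinyEncoderOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)

There is no `next_decoder_cache` tuple and no cross-attention bookkeeping left, so the return type narrows from `BaseModelOutputWithPastAndCrossAttentions` to `BaseModelOutput`.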
@ -554,6 +447,11 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -570,7 +468,7 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
) -> Union[tuple, BaseModelOutput]:
|
||||
r"""
|
||||
token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
|
||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
||||
@ -592,11 +490,6 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -610,11 +503,8 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
@ -622,17 +512,6 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
@ -645,31 +524,21 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output,) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=sequence_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
|
@ -435,11 +435,6 @@ class SwinSelfAttention(nn.Module):

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
self,
hidden_states: torch.Tensor,
@ -448,11 +443,11 @@ class SwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
hidden_shape = (batch_size, dim, -1, self.attention_head_size)

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

@ -45,6 +45,7 @@ from ...modeling_outputs import (
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_unispeech import UniSpeechConfig
|
||||
|
||||
|
||||
@ -332,6 +333,7 @@ class UniSpeechAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -339,7 +341,7 @@ class UniSpeechAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -360,42 +362,9 @@ class UniSpeechAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -417,7 +386,7 @@ class UniSpeechAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class UniSpeechFeedForward(nn.Module):
|
||||
|
@ -47,6 +47,7 @@ from ...modeling_outputs import (
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_unispeech_sat import UniSpeechSatConfig
|
||||
|
||||
|
||||
@ -337,6 +338,7 @@ class UniSpeechSatAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -344,7 +346,7 @@ class UniSpeechSatAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -365,42 +367,9 @@ class UniSpeechSatAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -422,7 +391,7 @@ class UniSpeechSatAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class UniSpeechSatFeedForward(nn.Module):
|
||||
|
@ -55,6 +55,7 @@ from ...utils import (
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_wav2vec2 import Wav2Vec2Config
|
||||
|
||||
|
||||
@ -524,6 +525,7 @@ class Wav2Vec2Attention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -531,7 +533,7 @@ class Wav2Vec2Attention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -552,42 +554,9 @@ class Wav2Vec2Attention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -609,7 +578,7 @@ class Wav2Vec2Attention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class Wav2Vec2FeedForward(nn.Module):
|
||||
|
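Every Wav2Vec2-derived attention module touched in this diff (PatchTSMixer, PatchTST, SEW, UniSpeech, UniSpeechSat, Wav2Vec2) replaces the old four-branch cache handling with a single choice of projection input. A hedged sketch of just that key/value path, using invented dimensions and class name:

import torch
from torch import nn

class TinyAttentionProjections(nn.Module):
    """Illustrative sketch (not the library class) of the simplified key/value path:
    with caching gone, the only remaining decision is whether keys/values come from
    the encoder states (cross-attention) or from the current hidden states."""

    def __init__(self, embed_dim=16, num_heads=4):
        super().__init__()
        self.head_dim = embed_dim // num_heads
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states, key_value_states=None):
        is_cross_attention = key_value_states is not None
        current_states = key_value_states if is_cross_attention else hidden_states
        bsz, src_len, _ = current_states.shape
        kv_input_shape = (bsz, src_len, -1, self.head_dim)
        key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
        value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
        return key_states, value_states

Cross-attention projects the encoder states, self-attention projects the current hidden states; nothing is cached, which is why these layers now return `None` in the slot that used to hold `past_key_value`.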
@ -30,6 +30,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
ModelOutput,
auto_docstring,
can_return_tuple,
logging,
torch_int,
)
@ -576,6 +577,7 @@ class XCLIPEncoder(nn.Module):
self.layers = nn.ModuleList([XCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

@can_return_tuple
def forward(
self,
inputs_embeds,
@ -642,8 +644,6 @@ class XCLIPEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)

@ -297,7 +297,7 @@ class AltCLIPTextModelTester:
@require_torch
class AltCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (AltCLIPTextModel,) if is_torch_available() else ()
fx_compatible = True
fx_compatible = False # Cannot support if `can_return_tuple`
test_pruning = False
test_head_masking = False

@ -411,7 +411,7 @@ def prepare_img():
class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (AltCLIPModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": AltCLIPModel} if is_torch_available() else {}
fx_compatible = True
fx_compatible = False # Cannot support if `can_return_tuple`
test_head_masking = False
test_pruning = False
test_resize_embeddings = False

@ -243,7 +243,7 @@ class LayoutLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
if is_torch_available()
else {}
)
fx_compatible = True
fx_compatible = False # Cannot support if `can_return_tuple`

def setUp(self):
self.model_tester = LayoutLMModelTester(self)

@ -372,6 +372,18 @@ class SplinterModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))

@unittest.skip(
"Splinter GC with `use_reentrant` fails after #38751, FIXME raushan after deprecated args are removed"
)
def test_training_gradient_checkpointing(self):
pass

@unittest.skip(
"Splinter GC with `use_reentrant` fails after #38751, FIXME raushan after deprecated args are removed"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass


@require_torch
class SplinterModelIntegrationTest(unittest.TestCase):

@ -276,6 +276,9 @@ SPECIAL_CASES_TO_ALLOW = {
"attention_chunk_size",
],
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
# position_embedding_type not used and deprecated. Should be deleted in v4.55
"LayoutLMConfig": ["position_embedding_type"],
"MarkupLMConfig": ["position_embedding_type"],
"SmolLM3Config": ["no_rope_layer_interval"],
"Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"], # this is for use in `timm`
}