data2vectext, making it modular tomorrow

Vasqu 2025-07-01 18:41:29 +02:00
parent dd7aeca424
commit ad3ffe55a9
2 changed files with 463 additions and 216 deletions

@ -14,16 +14,17 @@
# limitations under the License.
"""PyTorch Data2VecText model."""
import math
from typing import Optional, Union
from typing import Callable, Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, gelu
from ...cache_utils import Cache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@ -35,12 +36,16 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from .configuration_data2vec_text import Data2VecTextConfig
if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
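Together with the backend support flags added further down, these imports let the attention backend be selected at load time. A minimal sketch, assuming the public facebook/data2vec-text-base checkpoint and an environment where the chosen kernel is available:

from transformers import Data2VecTextModel

# any of the supported backends works here: "eager", "sdpa", "flash_attention_2", "flex_attention"
model = Data2VecTextModel.from_pretrained(
    "facebook/data2vec-text-base", attn_implementation="sdpa"
)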
@ -48,16 +53,12 @@ _HIDDEN_STATES_START_POSITION = 2
# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Data2VecText
class Data2VecTextForTextEmbeddings(nn.Module):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
class Data2VecTextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
@ -73,21 +74,27 @@ class Data2VecTextForTextEmbeddings(nn.Module):
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
# End copy
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
)
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if position_ids is None:
if input_ids is not None:
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
position_ids = self.create_position_ids_from_input_ids(
input_ids, self.padding_idx, past_key_values_length
)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
if input_ids is not None:
input_shape = input_ids.size()
@ -119,7 +126,8 @@ class Data2VecTextForTextEmbeddings(nn.Module):
embeddings = self.dropout(embeddings)
return embeddings
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
@staticmethod
def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
"""
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
@ -132,24 +140,100 @@ class Data2VecTextForTextEmbeddings(nn.Module):
sequence_length = input_shape[1]
position_ids = torch.arange(
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
)
return position_ids.unsqueeze(0).expand(input_shape)
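A quick illustration of the inputs_embeds path (a sketch, assuming the config default pad_token_id=1, i.e. padding_idx=1): positions are simply sequential, starting at padding_idx + 1.

import torch

inputs_embeds = torch.empty(2, 4, 16)  # [batch, seq_len, hidden]
Data2VecTextEmbeddings.create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx=1)
# tensor([[2, 3, 4, 5],
#         [2, 3, 4, 5]])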
@staticmethod
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
    input_ids: torch.Tensor
    padding_idx: int
    past_key_values_length: int, optional
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx
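And the input_ids path, under the same padding_idx=1 assumption: real tokens count up from padding_idx + 1 while padded positions stay at padding_idx.

import torch

input_ids = torch.tensor([[5, 7, 9, 1]])  # last token is padding (padding_idx=1)
Data2VecTextEmbeddings.create_position_ids_from_input_ids(input_ids, padding_idx=1)
# tensor([[2, 3, 4, 1]])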
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
**kwargs,
):
if scaling is None:
scaling = query.size(-1) ** -0.5
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3))
# Relative positional embeddings
if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
query_length, key_length = query.shape[2], key.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
if module.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
attn_weights = attn_weights + relative_position_scores
elif module.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key
# Scaling is applied here, after the (optional) relative position scores have been added
attn_weights = attn_weights * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
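A rough shape check for the eager path (a sketch; DummyAttn is a hypothetical stand-in that only provides the attributes the function reads):

import torch
from torch import nn

class DummyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.position_embedding_type = "absolute"  # the relative variants would also need distance_embedding

module = DummyAttn().eval()
q = torch.randn(2, 4, 8, 16)  # [batch, num_heads, tgt_len, head_dim]
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
attn_output, attn_weights = eager_attention_forward(module, q, k, v, attention_mask=None)
print(attn_output.shape)   # torch.Size([2, 8, 4, 16]) -- transposed back to [batch, tgt_len, num_heads, head_dim]
print(attn_weights.shape)  # torch.Size([2, 4, 8, 8])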
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Data2VecText
class Data2VecTextSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.scaling = self.attention_head_size**-0.5
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
@ -164,114 +248,178 @@ class Data2VecTextSelfAttention(nn.Module):
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.is_causal = is_causal
self.layer_idx = layer_idx
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
src_len = tgt_len
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
q_input_shape = (bsz, tgt_len, -1, self.attention_head_size)
kv_input_shape = (bsz, src_len, -1, self.attention_head_size)
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
# get all proj
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(*kv_input_shape).transpose(1, 2)
query_layer = self.transpose_for_scores(mixed_query_layer)
if past_key_value is not None:
# a decoder-only model can use a simple DynamicCache, for example
current_past_key_value = past_key_value
if isinstance(past_key_value, EncoderDecoderCache):
current_past_key_value = past_key_value.self_attention_cache
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# save the key/value states to the cache so they can be reused for fast auto-regressive generation
key_layer, value_layer = current_past_key_value.update(
key_layer,
value_layer,
self.layer_idx,
{"cache_position": cache_position},
)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.position_embedding_type != "absolute":
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in Data2VecTextModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
attn_output, attn_weights = attention_interface(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
dropout=0.0 if not self.training else self.dropout.p,
scaling=self.scaling,
head_mask=head_mask,
# only relevant for non-absolute positional embeddings
use_cache=past_key_value is not None,
**kwargs,
)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
outputs = (
attn_output,
attn_weights,
)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
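A minimal smoke test of the refactored self-attention layer (a sketch; the tiny config sizes are arbitrary and the module is randomly initialized):

import torch
from transformers import Data2VecTextConfig

config = Data2VecTextConfig(hidden_size=32, num_attention_heads=4)
config._attn_implementation = "eager"
attn = Data2VecTextSelfAttention(config, layer_idx=0).eval()
hidden_states = torch.randn(1, 6, 32)
attn_output, attn_weights = attn(hidden_states, attention_mask=None)[:2]
print(attn_output.shape)   # torch.Size([1, 6, 32])
print(attn_weights.shape)  # torch.Size([1, 4, 6, 6])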
# Copied from transformers.models.bert.modeling_bert.BertCrossAttention with Bert->Data2VecText
class Data2VecTextCrossAttention(nn.Module):
def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.scaling = self.attention_head_size**-0.5
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_causal = is_causal
self.layer_idx = layer_idx
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
**kwargs,
) -> tuple[torch.Tensor]:
# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
src_len = encoder_hidden_states.shape[1]
q_input_shape = (bsz, tgt_len, -1, self.attention_head_size)
kv_input_shape = (bsz, src_len, -1, self.attention_head_size)
# get query proj
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)
is_updated = past_key_value.is_updated.get(self.layer_idx) if past_key_value is not None else False
if past_key_value is not None and is_updated:
# reuse the cached cross-attention key/value states
key_layer = past_key_value.cross_attention_cache.key_cache[self.layer_idx]
value_layer = past_key_value.cross_attention_cache.value_cache[self.layer_idx]
else:
key_layer = self.key(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
value_layer = self.value(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all states to the cache
key_layer, value_layer = past_key_value.cross_attention_cache.update(
key_layer,
value_layer,
self.layer_idx,
)
# mark this layer's cross-attention as updated so subsequent calls can reuse the cached states
past_key_value.is_updated[self.layer_idx] = True
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.position_embedding_type != "absolute":
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
)
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_layer,
key_layer,
value_layer,
encoder_attention_mask,
dropout=0.0 if not self.training else self.dropout.p,
scaling=self.scaling,
head_mask=head_mask,
# only relevant for non-absolute positional embeddings
use_cache=past_key_value is not None,
**kwargs,
)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
outputs = (
attn_output,
attn_weights,
)
outputs = outputs + (past_key_value,)
return outputs
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class Data2VecTextSelfOutput(nn.Module):
def __init__(self, config):
@ -287,17 +435,16 @@ class Data2VecTextSelfOutput(nn.Module):
return hidden_states
DATA2VEC_TEXT_SELF_ATTENTION_CLASSES = {
"eager": Data2VecTextSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT
class Data2VecTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(
self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False
):
super().__init__()
self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
self.is_cross_attention = is_cross_attention
attention_class = Data2VecTextCrossAttention if is_cross_attention else Data2VecTextSelfAttention
self.self = attention_class(
config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx
)
self.output = Data2VecTextSelfOutput(config)
self.pruned_heads = set()
@ -328,17 +475,27 @@ class Data2VecTextAttention(nn.Module):
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
if self.is_cross_attention:
self_outputs = self.self(
hidden_states,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
**kwargs,
)
else:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
past_key_value,
cache_position,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
@ -377,17 +534,23 @@ class Data2VecTextOutput(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText
class Data2VecTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
def __init__(self, config, layer_idx=None):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = Data2VecTextAttention(config)
self.attention = Data2VecTextAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = Data2VecTextAttention(config, position_embedding_type="absolute")
self.crossattention = Data2VecTextAttention(
config,
position_embedding_type="absolute",
is_causal=False,
layer_idx=layer_idx,
is_cross_attention=True,
)
self.intermediate = Data2VecTextIntermediate(config)
self.output = Data2VecTextOutput(config)
@ -398,28 +561,25 @@ class Data2VecTextLayer(GradientCheckpointingLayer):
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
past_key_value=past_key_value,
cache_position=cache_position,
**kwargs,
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
@ -427,24 +587,18 @@ class Data2VecTextLayer(GradientCheckpointingLayer):
" by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
None, # attention_mask
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
past_key_value=past_key_value,
**kwargs,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
@ -452,7 +606,7 @@ class Data2VecTextLayer(GradientCheckpointingLayer):
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
outputs = outputs + (past_key_value,)
return outputs
@ -467,8 +621,7 @@ class Data2VecTextEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([Data2VecTextLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self.layer = nn.ModuleList([Data2VecTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
def forward(
self,
@ -477,30 +630,23 @@ class Data2VecTextEncoder(nn.Module):
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
cache_position: Optional[torch.Tensor] = None,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
next_decoder_cache = () if use_cache else None
next_decoder_cache = None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states,
@ -508,13 +654,13 @@ class Data2VecTextEncoder(nn.Module):
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
past_key_value=past_key_values,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
next_decoder_cache = layer_outputs[-1]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
@ -523,12 +669,14 @@ class Data2VecTextEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
next_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
@ -537,7 +685,7 @@ class Data2VecTextEncoder(nn.Module):
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
@ -566,6 +714,12 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
base_model_prefix = "data2vec_text"
supports_gradient_checkpointing = True
_no_split_modules = ["Data2VecTextForTextEmbeddings", "Data2VecTextLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
# Rarely used with this model, so static-cache support is low priority for now
_supports_static_cache = False
def _init_weights(self, module):
"""Initialize the weights"""
@ -610,8 +764,9 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
"""
super().__init__(config)
self.config = config
self.gradient_checkpointing = False
self.embeddings = Data2VecTextForTextEmbeddings(config)
self.embeddings = Data2VecTextEmbeddings(config)
self.encoder = Data2VecTextEncoder(config)
self.pooler = Data2VecTextPooler(config) if add_pooling_layer else None
@ -644,11 +799,12 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -671,14 +827,29 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
return_legacy_cache = True
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
@ -688,20 +859,43 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
if attention_mask is None:
# the required mask sequence length can be derived from the length of the past cache
mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
if self.config.is_decoder and encoder_hidden_states is not None and encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=device)
if self.config.is_decoder:
attention_mask = create_causal_mask(
config=self.config,
input_embeds=embedding_output,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
else:
encoder_extended_attention_mask = None
attention_mask = self._update_full_mask(
attention_mask,
embedding_output,
)
if encoder_attention_mask is not None:
encoder_attention_mask = self._update_cross_attn_mask(
encoder_hidden_states,
encoder_attention_mask,
embedding_output.shape[:2],
embedding_output,
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@ -710,19 +904,12 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
@ -732,6 +919,9 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if return_legacy_cache:
encoder_outputs.past_key_values = encoder_outputs.past_key_values.to_legacy_cache()
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
@ -744,6 +934,65 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
cross_attentions=encoder_outputs.cross_attentions,
)
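With the cache migration above, a decoder-configured model can be driven with an explicit EncoderDecoderCache instead of the legacy tuples (a sketch with an arbitrary small, randomly initialized config):

import torch
from transformers import Data2VecTextConfig, Data2VecTextModel, DynamicCache, EncoderDecoderCache

config = Data2VecTextConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=64, is_decoder=True,
)
model = Data2VecTextModel(config).eval()
input_ids = torch.randint(0, 100, (1, 5))
past = EncoderDecoderCache(DynamicCache(), DynamicCache())
outputs = model(input_ids, past_key_values=past, use_cache=True)
print(outputs.past_key_values.self_attention_cache.get_seq_length())  # 5 cached key/value positions per layer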
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
def _update_full_mask(
self,
attention_mask: Union[torch.Tensor, None],
inputs_embeds: torch.Tensor,
):
if attention_mask is not None:
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa":
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
elif self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
return attention_mask
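For the default eager path this is the usual additive-mask expansion; a quick check of what _prepare_4d_attention_mask produces:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

mask_2d = torch.tensor([[1, 1, 0]])  # 0 marks a padded position
mask_4d = _prepare_4d_attention_mask(mask_2d, torch.float32)
print(mask_4d.shape)  # torch.Size([1, 1, 3, 3]); padded columns hold a large negative value, attended ones hold 0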
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
def _update_cross_attn_mask(
self,
encoder_hidden_states: Union[torch.Tensor, None],
encoder_attention_mask: Union[torch.Tensor, None],
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self.config._attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.config._attn_implementation == "sdpa":
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
elif self.config._attn_implementation == "flex_attention":
if isinstance(encoder_attention_mask, torch.Tensor):
encoder_attention_mask = make_flex_block_causal_mask(
encoder_attention_mask,
query_length=input_shape[-1],
is_causal=False,
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
return encoder_attention_mask
@auto_docstring(
custom_intro="""
@ -788,6 +1037,7 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
r"""
@ -830,6 +1080,7 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
sequence_output = outputs[0]
@ -1351,22 +1602,6 @@ class Data2VecTextForQuestionAnswering(Data2VecTextPreTrainedModel):
)
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
x: torch.Tensor x:
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx
__all__ = [
"Data2VecTextForCausalLM",
"Data2VecTextForMaskedLM",

@ -37,10 +37,7 @@ if is_torch_available():
Data2VecTextForTokenClassification,
Data2VecTextModel,
)
from transformers.models.data2vec.modeling_data2vec_text import (
Data2VecTextForTextEmbeddings,
create_position_ids_from_input_ids,
)
from transformers.models.data2vec.modeling_data2vec_text import Data2VecTextEmbeddings
class Data2VecTextModelTester:
@ -387,6 +384,12 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
)
model_split_percents = [0.5, 0.9]
# Overridden to set the `is_decoder` flag
def prepare_config_and_inputs_for_generate(self, batch_size=2):
config, inputs = super().prepare_config_and_inputs_for_generate(batch_size)
config.is_decoder = True
return config, inputs
def setUp(self):
self.model_tester = Data2VecTextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Data2VecTextConfig, hidden_size=37)
@ -402,6 +405,7 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
config_and_inputs[0]._attn_implementation = "eager"
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_as_decoder(self):
@ -446,6 +450,7 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
config_and_inputs[0].position_embedding_type = "relative_key"
config_and_inputs[0]._attn_implementation = "eager"
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_for_masked_lm(self):
@ -477,14 +482,14 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
first available non-padding position index is Data2VecTextForTextEmbeddings.padding_idx + 1
"""
config = self.model_tester.prepare_config_and_inputs()[0]
model = Data2VecTextForTextEmbeddings(config=config)
model = Data2VecTextEmbeddings(config=config)
input_ids = torch.as_tensor([[12, 31, 13, model.padding_idx]])
expected_positions = torch.as_tensor(
[[0 + model.padding_idx + 1, 1 + model.padding_idx + 1, 2 + model.padding_idx + 1, model.padding_idx]]
)
position_ids = create_position_ids_from_input_ids(input_ids, model.padding_idx)
position_ids = Data2VecTextEmbeddings.create_position_ids_from_input_ids(input_ids, model.padding_idx)
self.assertEqual(position_ids.shape, expected_positions.shape)
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
@ -495,7 +500,7 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
first available non-padding position index is Data2VecTextForTextEmbeddings.padding_idx + 1
"""
config = self.model_tester.prepare_config_and_inputs()[0]
embeddings = Data2VecTextForTextEmbeddings(config=config)
embeddings = Data2VecTextEmbeddings(config=config)
inputs_embeds = torch.empty(2, 4, 30)
expected_single_positions = [
@ -505,10 +510,17 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
3 + embeddings.padding_idx + 1,
]
expected_positions = torch.as_tensor([expected_single_positions, expected_single_positions])
position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds)
position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds, embeddings.padding_idx)
self.assertEqual(position_ids.shape, expected_positions.shape)
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
@unittest.skip("Data2VecText token type ids does not work with the flash attention position ids")
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Data2VecText token type ids does not work with the flash attention position ids")
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
pass
@require_torch
class Data2VecTextModelIntegrationTest(TestCasePlus):