mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
data2vectext, making it modular tomorrow
Some checks are pending
Secret Leaks / trufflehog (push) Waiting to run
Some checks are pending
Secret Leaks / trufflehog (push) Waiting to run
This commit is contained in:
parent
dd7aeca424
commit
ad3ffe55a9
@ -14,16 +14,17 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Data2VecText model."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN, gelu
|
||||
from ...cache_utils import Cache, EncoderDecoderCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@ -35,12 +36,16 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
||||
from .configuration_data2vec_text import Data2VecTextConfig
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -48,16 +53,12 @@ _HIDDEN_STATES_START_POSITION = 2
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Data2VecText
|
||||
class Data2VecTextForTextEmbeddings(nn.Module):
|
||||
"""
|
||||
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
||||
"""
|
||||
class Data2VecTextEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
@ -73,21 +74,27 @@ class Data2VecTextForTextEmbeddings(nn.Module):
|
||||
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||
)
|
||||
|
||||
# End copy
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.position_embeddings = nn.Embedding(
|
||||
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||
):
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
past_key_values_length: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if position_ids is None:
|
||||
if input_ids is not None:
|
||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
|
||||
position_ids = self.create_position_ids_from_input_ids(
|
||||
input_ids, self.padding_idx, past_key_values_length
|
||||
)
|
||||
else:
|
||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
|
||||
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
@ -119,7 +126,8 @@ class Data2VecTextForTextEmbeddings(nn.Module):
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
|
||||
@staticmethod
|
||||
def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
|
||||
"""
|
||||
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
|
||||
|
||||
@ -132,24 +140,100 @@ class Data2VecTextForTextEmbeddings(nn.Module):
|
||||
sequence_length = input_shape[1]
|
||||
|
||||
position_ids = torch.arange(
|
||||
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
|
||||
padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
|
||||
)
|
||||
return position_ids.unsqueeze(0).expand(input_shape)
|
||||
|
||||
@staticmethod
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
||||
"""
|
||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
||||
are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||
|
||||
Args:
|
||||
x: torch.Tensor x:
|
||||
|
||||
Returns: torch.Tensor
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = input_ids.ne(padding_idx).int()
|
||||
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
|
||||
return incremental_indices.long() + padding_idx
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: Optional[float] = None,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if scaling is None:
|
||||
scaling = query.size(-1) ** -0.5
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3))
|
||||
|
||||
# Relative positional embeddings
|
||||
if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query.shape[2], key.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
|
||||
|
||||
if module.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
|
||||
attn_weights = attn_weights + relative_position_scores
|
||||
elif module.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
|
||||
attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
# Scaling is shifted in case of embeddings being relative
|
||||
attn_weights = attn_weights * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Data2VecText
|
||||
class Data2VecTextSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
@ -164,114 +248,178 @@ class Data2VecTextSelfAttention(nn.Module):
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.is_causal = is_causal
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
# determine input shapes
|
||||
bsz, tgt_len = hidden_states.shape[:-1]
|
||||
src_len = tgt_len
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
q_input_shape = (bsz, tgt_len, -1, self.attention_head_size)
|
||||
kv_input_shape = (bsz, src_len, -1, self.attention_head_size)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
# get all proj
|
||||
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
if past_key_value is not None:
|
||||
# decoder-only roberta can have a simple dynamic cache for example
|
||||
current_past_key_value = past_key_value
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
current_past_key_value = past_key_value.self_attention_cache
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
# save all key/value_layer to cache to be re-used for fast auto-regressive generation
|
||||
key_layer, value_layer = current_past_key_value.update(
|
||||
key_layer,
|
||||
value_layer,
|
||||
self.layer_idx,
|
||||
{"cache_position": cache_position},
|
||||
)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.position_embedding_type != "absolute":
|
||||
raise ValueError(
|
||||
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
|
||||
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in Data2VecTextModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout.p,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
# only for relevant for non-absolute positional embeddings
|
||||
use_cache=past_key_value is not None,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
|
||||
outputs = (
|
||||
attn_output,
|
||||
attn_weights,
|
||||
)
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertCrossAttention with Bert->Data2VecText
|
||||
class Data2VecTextCrossAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_causal = is_causal
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# determine input shapes
|
||||
bsz, tgt_len = hidden_states.shape[:-1]
|
||||
src_len = encoder_hidden_states.shape[1]
|
||||
|
||||
q_input_shape = (bsz, tgt_len, -1, self.attention_head_size)
|
||||
kv_input_shape = (bsz, src_len, -1, self.attention_head_size)
|
||||
|
||||
# get query proj
|
||||
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx) if past_key_value is not None else False
|
||||
if past_key_value is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value.cross_attention_cache.key_cache[self.layer_idx]
|
||||
value_layer = past_key_value.cross_attention_cache.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_layer = self.key(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_layer = self.value(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
# save all states to the cache
|
||||
key_layer, value_layer = past_key_value.cross_attention_cache.update(
|
||||
key_layer,
|
||||
value_layer,
|
||||
self.layer_idx,
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.position_embedding_type != "absolute":
|
||||
raise ValueError(
|
||||
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
|
||||
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
|
||||
)
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
encoder_attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout.p,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
# only for relevant for non-absolute positional embeddings
|
||||
use_cache=past_key_value is not None,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
|
||||
outputs = (
|
||||
attn_output,
|
||||
attn_weights,
|
||||
)
|
||||
outputs = outputs + (past_key_value,)
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
|
||||
class Data2VecTextSelfOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
@ -287,17 +435,16 @@ class Data2VecTextSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
DATA2VEC_TEXT_SELF_ATTENTION_CLASSES = {
|
||||
"eager": Data2VecTextSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT
|
||||
class Data2VecTextAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(
|
||||
self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False
|
||||
):
|
||||
super().__init__()
|
||||
self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
self.is_cross_attention = is_cross_attention
|
||||
attention_class = Data2VecTextCrossAttention if is_cross_attention else Data2VecTextSelfAttention
|
||||
self.self = attention_class(
|
||||
config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx
|
||||
)
|
||||
self.output = Data2VecTextSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
@ -328,17 +475,27 @@ class Data2VecTextAttention(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
if self.is_cross_attention:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
past_key_value,
|
||||
cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
@ -377,17 +534,23 @@ class Data2VecTextOutput(nn.Module):
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText
|
||||
class Data2VecTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, layer_idx=None):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = Data2VecTextAttention(config)
|
||||
self.attention = Data2VecTextAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = Data2VecTextAttention(config, position_embedding_type="absolute")
|
||||
self.crossattention = Data2VecTextAttention(
|
||||
config,
|
||||
position_embedding_type="absolute",
|
||||
is_causal=False,
|
||||
layer_idx=layer_idx,
|
||||
is_cross_attention=True,
|
||||
)
|
||||
self.intermediate = Data2VecTextIntermediate(config)
|
||||
self.output = Data2VecTextOutput(config)
|
||||
|
||||
@ -398,28 +561,25 @@ class Data2VecTextLayer(GradientCheckpointingLayer):
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
past_key_value=past_key_value,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
@ -427,24 +587,18 @@ class Data2VecTextLayer(GradientCheckpointingLayer):
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
None, # attention_mask
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
past_key_value=past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
@ -452,7 +606,7 @@ class Data2VecTextLayer(GradientCheckpointingLayer):
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
outputs = outputs + (past_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -467,8 +621,7 @@ class Data2VecTextEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([Data2VecTextLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
self.layer = nn.ModuleList([Data2VecTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -477,30 +630,23 @@ class Data2VecTextEncoder(nn.Module):
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
@ -508,13 +654,13 @@ class Data2VecTextEncoder(nn.Module):
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=past_key_values,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
next_decoder_cache = layer_outputs[-1]
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
@ -523,12 +669,14 @@ class Data2VecTextEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
@ -537,7 +685,7 @@ class Data2VecTextEncoder(nn.Module):
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
@ -566,6 +714,12 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "data2vec_text"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Data2VecTextForTextEmbeddings", "Data2VecTextLayer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
# Not too much usage so low prio to fix
|
||||
_supports_static_cache = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -610,8 +764,9 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.embeddings = Data2VecTextForTextEmbeddings(config)
|
||||
self.embeddings = Data2VecTextEmbeddings(config)
|
||||
self.encoder = Data2VecTextEncoder(config)
|
||||
|
||||
self.pooler = Data2VecTextPooler(config) if add_pooling_layer else None
|
||||
@ -644,11 +799,12 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -671,14 +827,29 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
logger.warning_once(
|
||||
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
|
||||
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
|
||||
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
|
||||
)
|
||||
return_legacy_cache = True
|
||||
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
@ -688,20 +859,43 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
else:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
if attention_mask is None:
|
||||
# required mask seq length can be calculated via length of past cache
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
|
||||
|
||||
if self.config.is_decoder and encoder_hidden_states is not None and encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=device)
|
||||
|
||||
if self.config.is_decoder:
|
||||
attention_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=embedding_output,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
attention_mask = self._update_full_mask(
|
||||
attention_mask,
|
||||
embedding_output,
|
||||
)
|
||||
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = self._update_cross_attn_mask(
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
embedding_output.shape[:2],
|
||||
embedding_output,
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
@ -710,19 +904,12 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
@ -732,6 +919,9 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if return_legacy_cache:
|
||||
encoder_outputs.past_key_values = encoder_outputs.past_key_values.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
@ -744,6 +934,65 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
|
||||
def _update_full_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, None],
|
||||
inputs_embeds: torch.Tensor,
|
||||
):
|
||||
if attention_mask is not None:
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask if 0 in attention_mask else None
|
||||
elif self.config._attn_implementation == "sdpa":
|
||||
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
|
||||
elif self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
||||
|
||||
return attention_mask
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
|
||||
def _update_cross_attn_mask(
|
||||
self,
|
||||
encoder_hidden_states: Union[torch.Tensor, None],
|
||||
encoder_attention_mask: Union[torch.Tensor, None],
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
):
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
||||
elif self.config._attn_implementation == "sdpa":
|
||||
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||
encoder_attention_mask,
|
||||
inputs_embeds.dtype,
|
||||
tgt_len=input_shape[-1],
|
||||
)
|
||||
elif self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(encoder_attention_mask, torch.Tensor):
|
||||
encoder_attention_mask = make_flex_block_causal_mask(
|
||||
encoder_attention_mask,
|
||||
query_length=input_shape[-1],
|
||||
is_causal=False,
|
||||
)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
return encoder_attention_mask
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
@ -788,6 +1037,7 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
@ -830,6 +1080,7 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1351,22 +1602,6 @@ class Data2VecTextForQuestionAnswering(Data2VecTextPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
||||
"""
|
||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
||||
are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||
|
||||
Args:
|
||||
x: torch.Tensor x:
|
||||
|
||||
Returns: torch.Tensor
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = input_ids.ne(padding_idx).int()
|
||||
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
|
||||
return incremental_indices.long() + padding_idx
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Data2VecTextForCausalLM",
|
||||
"Data2VecTextForMaskedLM",
|
||||
|
@ -37,10 +37,7 @@ if is_torch_available():
|
||||
Data2VecTextForTokenClassification,
|
||||
Data2VecTextModel,
|
||||
)
|
||||
from transformers.models.data2vec.modeling_data2vec_text import (
|
||||
Data2VecTextForTextEmbeddings,
|
||||
create_position_ids_from_input_ids,
|
||||
)
|
||||
from transformers.models.data2vec.modeling_data2vec_text import Data2VecTextEmbeddings
|
||||
|
||||
|
||||
class Data2VecTextModelTester:
|
||||
@ -387,6 +384,12 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
|
||||
)
|
||||
model_split_percents = [0.5, 0.9]
|
||||
|
||||
# Overwriting to add `is_decoder` flag
|
||||
def prepare_config_and_inputs_for_generate(self, batch_size=2):
|
||||
config, inputs = super().prepare_config_and_inputs_for_generate(batch_size)
|
||||
config.is_decoder = True
|
||||
return config, inputs
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Data2VecTextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Data2VecTextConfig, hidden_size=37)
|
||||
@ -402,6 +405,7 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
config_and_inputs[0]._attn_implementation = "eager"
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_as_decoder(self):
|
||||
@ -446,6 +450,7 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
|
||||
def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
config_and_inputs[0].position_embedding_type = "relative_key"
|
||||
config_and_inputs[0]._attn_implementation = "eager"
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
@ -477,14 +482,14 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
|
||||
first available non-padding position index is Data2VecTextForTextEmbeddings.padding_idx + 1
|
||||
"""
|
||||
config = self.model_tester.prepare_config_and_inputs()[0]
|
||||
model = Data2VecTextForTextEmbeddings(config=config)
|
||||
model = Data2VecTextEmbeddings(config=config)
|
||||
|
||||
input_ids = torch.as_tensor([[12, 31, 13, model.padding_idx]])
|
||||
expected_positions = torch.as_tensor(
|
||||
[[0 + model.padding_idx + 1, 1 + model.padding_idx + 1, 2 + model.padding_idx + 1, model.padding_idx]]
|
||||
)
|
||||
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, model.padding_idx)
|
||||
position_ids = Data2VecTextEmbeddings.create_position_ids_from_input_ids(input_ids, model.padding_idx)
|
||||
self.assertEqual(position_ids.shape, expected_positions.shape)
|
||||
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
|
||||
|
||||
@ -495,7 +500,7 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
|
||||
first available non-padding position index is Data2VecTextForTextEmbeddings.padding_idx + 1
|
||||
"""
|
||||
config = self.model_tester.prepare_config_and_inputs()[0]
|
||||
embeddings = Data2VecTextForTextEmbeddings(config=config)
|
||||
embeddings = Data2VecTextEmbeddings(config=config)
|
||||
|
||||
inputs_embeds = torch.empty(2, 4, 30)
|
||||
expected_single_positions = [
|
||||
@ -505,10 +510,17 @@ class Data2VecTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
|
||||
3 + embeddings.padding_idx + 1,
|
||||
]
|
||||
expected_positions = torch.as_tensor([expected_single_positions, expected_single_positions])
|
||||
position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||
position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds, embeddings.padding_idx)
|
||||
self.assertEqual(position_ids.shape, expected_positions.shape)
|
||||
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
|
||||
|
||||
@unittest.skip("Data2VecText token type ids does not work with the flash attention position ids")
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Data2VecText token type ids does not work with the flash attention position ids")
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
|
||||
pass
|
||||
|
||||
@require_torch
|
||||
class Data2VecTextModelIntegrationTest(TestCasePlus):
|
||||
|
Loading…
Reference in New Issue
Block a user