Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[RoBERTa-based] Add support for sdpa (#30510)
* Adding SDPA support for RoBERTa-based models
* add not is_cross_attention
* fix copies
* fix test
* add minimal test for camembert and xlm_roberta as their test class does not inherit from ModelTesterMixin
* address some review comments
* use copied from
* style
* consistency
* fix lists

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

This commit is contained in:
parent e0b87b0f40
commit f1a385b1de
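Illustrative usage (not part of the diff): after this change, RoBERTa-family models accept the `attn_implementation` argument at load time; the checkpoint name below is just an example.

```python
import torch
from transformers import AutoModel

# Request PyTorch's scaled_dot_product_attention backend explicitly;
# "eager" keeps the original manual attention implementation.
model = AutoModel.from_pretrained(
    "roberta-base",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
print(model.config._attn_implementation)  # "sdpa"
```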
@@ -74,7 +74,6 @@ FlashAttention-2 is currently supported for the following architectures:
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)

@@ -86,6 +85,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)

@@ -204,9 +204,11 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)

@@ -216,10 +218,16 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)

@@ -233,6 +241,14 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel)
* [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)

@@ -243,12 +259,9 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
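The lists above belong to the GPU inference docs touched by this commit. A hedged micro-benchmark sketch (hypothetical checkpoint, CUDA and fp16 assumed) to compare the two backends end to end:

```python
import time
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("roberta-base")
batch = tok(["Paris is the capital of France."] * 32, padding=True, return_tensors="pt").to("cuda")

def bench(impl: str) -> float:
    # Load the same checkpoint with a given attention implementation and time a forward pass.
    model = AutoModel.from_pretrained(
        "roberta-base", attn_implementation=impl, torch_dtype=torch.float16
    ).to("cuda").eval()
    with torch.inference_mode():
        for _ in range(3):  # warmup
            model(**batch)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(20):
            model(**batch)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / 20

print({impl: bench(impl) for impl in ("eager", "sdpa")})
```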
@@ -1063,7 +1063,7 @@ class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

# Copied from transformers.models.roberta.modeling_roberta.RobertaModel.forward
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
@@ -20,10 +20,15 @@ from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, gelu
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,

@@ -40,6 +45,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
get_torch_version,
logging,
replace_return_docstrings,
)

@@ -294,6 +300,108 @@ class CamembertSelfAttention(nn.Module):
return outputs


# Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->Camembert
class CamembertSdpaSelfAttention(CamembertSelfAttention):
def __init__(self, config, position_embedding_type=None):
super().__init__(config, position_embedding_type=position_embedding_type)
self.dropout_prob = config.attention_probs_dropout_prob
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

# Adapted from CamembertSelfAttention
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
"CamembertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
"the manual attention implementation, but specifying the manual implementation will be required from "
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
'`attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

bsz, tgt_len, _ = hidden_states.size()

query_layer = self.transpose_for_scores(self.query(hidden_states))

# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
is_causal = (
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs


# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->Camembert
class CamembertSelfOutput(nn.Module):
def __init__(self, config):

@@ -311,6 +419,7 @@ class CamembertSelfOutput(nn.Module):

CAMEMBERT_SELF_ATTENTION_CLASSES = {
"eager": CamembertSelfAttention,
"sdpa": CamembertSdpaSelfAttention,
}
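`CamembertAttention` itself is not shown in this hunk; it selects the self-attention class from this mapping, following the same dispatch pattern made explicit in the `XLMRobertaXLAttention` hunk further down. A rough standalone illustration (the manual `_attn_implementation` assignment mimics what `from_pretrained` normally does):

```python
from transformers import CamembertConfig
from transformers.models.camembert.modeling_camembert import CAMEMBERT_SELF_ATTENTION_CLASSES

config = CamembertConfig()
config._attn_implementation = "sdpa"  # normally set during from_pretrained
attn = CAMEMBERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
    config, position_embedding_type=config.position_embedding_type
)
print(type(attn).__name__)  # CamembertSdpaSelfAttention
```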
@@ -603,6 +712,7 @@ class CamembertPreTrainedModel(PreTrainedModel):
config_class = CamembertConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
_supports_sdpa = True

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):

@@ -749,7 +859,7 @@ class CamembertModel(CamembertPreTrainedModel):

_no_split_modules = []

# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Camembert
# Copied from transformers.models.roberta.modeling_roberta.RobertaModel.__init__ with Roberta->Camembert
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config

@@ -759,6 +869,9 @@ class CamembertModel(CamembertPreTrainedModel):

self.pooler = CamembertPooler(config) if add_pooling_layer else None

self.attn_implementation = config._attn_implementation
self.position_embedding_type = config.position_embedding_type

# Initialize weights and apply final processing
self.post_init()

@@ -782,7 +895,7 @@ class CamembertModel(CamembertPreTrainedModel):
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
# Copied from transformers.models.roberta.modeling_roberta.RobertaModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

@@ -846,9 +959,6 @@ class CamembertModel(CamembertPreTrainedModel):
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]

@@ -857,9 +967,43 @@ class CamembertModel(CamembertPreTrainedModel):
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)

if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

use_sdpa_attention_masks = (
self.attn_implementation == "sdpa"
and self.position_embedding_type == "absolute"
and head_mask is None
and not output_attentions
)

# Expand the attention mask
if use_sdpa_attention_masks:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
if self.config.is_decoder:
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
embedding_output,
past_key_values_length,
)
else:
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

@@ -868,7 +1012,15 @@ class CamembertModel(CamembertPreTrainedModel):
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

if use_sdpa_attention_masks:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None

@@ -879,13 +1031,6 @@ class CamembertModel(CamembertPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
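The commit message mentions a minimal test for CamemBERT and XLM-RoBERTa; a sketch of that kind of eager-vs-SDPA equivalence check (checkpoint and tolerance are illustrative):

```python
import torch
from transformers import AutoModel, AutoTokenizer

ckpt = "camembert-base"  # illustrative checkpoint
tok = AutoTokenizer.from_pretrained(ckpt)
inputs = tok("J'aime le camembert !", return_tensors="pt")

eager = AutoModel.from_pretrained(ckpt, attn_implementation="eager").eval()
sdpa = AutoModel.from_pretrained(ckpt, attn_implementation="sdpa").eval()

with torch.no_grad():
    out_eager = eager(**inputs).last_hidden_state
    out_sdpa = sdpa(**inputs).last_hidden_state

# fp32 outputs of the two attention paths should agree closely
print(torch.allclose(out_eager, out_sdpa, atol=1e-5))
```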
@@ -20,10 +20,15 @@ from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, gelu
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,

@@ -40,6 +45,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
get_torch_version,
logging,
replace_return_docstrings,
)

@@ -276,6 +282,108 @@ class RobertaSelfAttention(nn.Module):
return outputs


# Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->Roberta
class RobertaSdpaSelfAttention(RobertaSelfAttention):
def __init__(self, config, position_embedding_type=None):
super().__init__(config, position_embedding_type=position_embedding_type)
self.dropout_prob = config.attention_probs_dropout_prob
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

# Adapted from RobertaSelfAttention
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
"RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
"the manual attention implementation, but specifying the manual implementation will be required from "
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
'`attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

bsz, tgt_len, _ = hidden_states.size()

query_layer = self.transpose_for_scores(self.query(hidden_states))

# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
is_causal = (
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs


# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class RobertaSelfOutput(nn.Module):
def __init__(self, config):

@@ -293,6 +401,7 @@ class RobertaSelfOutput(nn.Module):

ROBERTA_SELF_ATTENTION_CLASSES = {
"eager": RobertaSelfAttention,
"sdpa": RobertaSdpaSelfAttention,
}

@@ -585,7 +694,8 @@ class RobertaPreTrainedModel(PreTrainedModel):
config_class = RobertaConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
_no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention"]
_no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention", "RobertaSdpaSelfAttention"]
_supports_sdpa = True

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):

@@ -676,23 +786,22 @@ ROBERTA_INPUTS_DOCSTRING = r"""
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->Roberta, BERT->ROBERTA
class RobertaModel(RobertaPreTrainedModel):
"""

The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in *Attention is
all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
Kaiser and Illia Polosukhin.
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.

.. _*Attention is all you need*: https://arxiv.org/abs/1706.03762

"""

# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Roberta
_no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]

def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config

@@ -702,6 +811,9 @@ class RobertaModel(RobertaPreTrainedModel):

self.pooler = RobertaPooler(config) if add_pooling_layer else None

self.attn_implementation = config._attn_implementation
self.position_embedding_type = config.position_embedding_type

# Initialize weights and apply final processing
self.post_init()

@@ -725,7 +837,6 @@ class RobertaModel(RobertaPreTrainedModel):
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

@@ -789,9 +900,6 @@ class RobertaModel(RobertaPreTrainedModel):
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]

@@ -800,9 +908,43 @@ class RobertaModel(RobertaPreTrainedModel):
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)

if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

use_sdpa_attention_masks = (
self.attn_implementation == "sdpa"
and self.position_embedding_type == "absolute"
and head_mask is None
and not output_attentions
)

# Expand the attention mask
if use_sdpa_attention_masks:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
if self.config.is_decoder:
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
embedding_output,
past_key_values_length,
)
else:
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

@@ -811,7 +953,15 @@ class RobertaModel(RobertaPreTrainedModel):
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

if use_sdpa_attention_masks:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None

@@ -822,13 +972,6 @@ class RobertaModel(RobertaPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
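In the `use_sdpa_attention_masks` branch above, the additive mask from `get_extended_attention_mask` is replaced by the 4D SDPA helpers. A rough plain-torch illustration of what the non-causal expansion does, not the helper's actual implementation:

```python
import torch

def expand_padding_mask(mask_2d: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # mask_2d: [bsz, src_len] with 1 for real tokens, 0 for padding.
    bsz, src_len = mask_2d.shape
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # 0.0 where attention is allowed, a large negative value where it is masked.
    return (1.0 - expanded) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])
print(expand_padding_mask(mask, torch.float32, tgt_len=4).shape)  # torch.Size([1, 1, 4, 4])
```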
@@ -569,7 +569,6 @@ class RobertaPreLayerNormPooler(nn.Module):
return pooled_output


# Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
class RobertaPreLayerNormPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -20,10 +20,15 @@ from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, gelu
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,

@@ -40,6 +45,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
get_torch_version,
logging,
replace_return_docstrings,
)

@@ -277,6 +283,108 @@ class XLMRobertaSelfAttention(nn.Module):
return outputs


# Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->XLMRoberta
class XLMRobertaSdpaSelfAttention(XLMRobertaSelfAttention):
def __init__(self, config, position_embedding_type=None):
super().__init__(config, position_embedding_type=position_embedding_type)
self.dropout_prob = config.attention_probs_dropout_prob
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

# Adapted from XLMRobertaSelfAttention
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
"XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
"the manual attention implementation, but specifying the manual implementation will be required from "
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
'`attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

bsz, tgt_len, _ = hidden_states.size()

query_layer = self.transpose_for_scores(self.query(hidden_states))

# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
is_causal = (
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=is_causal,
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs


# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->XLMRoberta
class XLMRobertaSelfOutput(nn.Module):
def __init__(self, config):

@@ -294,6 +402,7 @@ class XLMRobertaSelfOutput(nn.Module):

XLM_ROBERTA_SELF_ATTENTION_CLASSES = {
"eager": XLMRobertaSelfAttention,
"sdpa": XLMRobertaSdpaSelfAttention,
}

@@ -587,7 +696,8 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
config_class = XLMRobertaConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
_no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention"]
_no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention", "XLMRobertaSdpaSelfAttention"]
_supports_sdpa = True

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):

@@ -682,19 +792,17 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
"""

The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in *Attention is
all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
Kaiser and Illia Polosukhin.
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.

.. _*Attention is all you need*: https://arxiv.org/abs/1706.03762

"""

# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRoberta
_no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaLayer"]

def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config

@@ -704,6 +812,9 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):

self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None

self.attn_implementation = config._attn_implementation
self.position_embedding_type = config.position_embedding_type

# Initialize weights and apply final processing
self.post_init()

@@ -727,7 +838,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

@@ -791,9 +901,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]

@@ -802,9 +909,43 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)

if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

use_sdpa_attention_masks = (
self.attn_implementation == "sdpa"
and self.position_embedding_type == "absolute"
and head_mask is None
and not output_attentions
)

# Expand the attention mask
if use_sdpa_attention_masks:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
if self.config.is_decoder:
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
embedding_output,
past_key_values_length,
)
else:
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

@@ -813,7 +954,15 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

if use_sdpa_attention_masks:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None

@@ -824,13 +973,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
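The `is_causal` flag computed in each new attention class lets SDPA build the causal mask itself when there is no explicit mask (decoder self-attention without padding); otherwise the expanded `attn_mask` is passed. A minimal standalone comparison of the two call patterns:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 12, 5, 64)  # [bsz, num_heads, seq_len, head_dim]

# Decoder self-attention, no padding mask: let the kernel apply causal masking.
causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Same masking expressed as an explicit additive mask (what the 4D helpers produce).
causal_mask = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
masked_out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

print(torch.allclose(causal_out, masked_out, atol=1e-6))  # True
```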
@ -19,10 +19,15 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN, gelu
|
||||
from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
@ -39,6 +44,7 @@ from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
get_torch_version,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -274,6 +280,108 @@ class XLMRobertaXLSelfAttention(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->XLMRobertaXL
|
||||
class XLMRobertaXLSdpaSelfAttention(XLMRobertaXLSelfAttention):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__(config, position_embedding_type=position_embedding_type)
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
|
||||
|
||||
# Adapted from XLMRobertaXLSelfAttention
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
|
||||
logger.warning_once(
|
||||
"XLMRobertaXLSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
|
||||
"the manual attention implementation, but specifying the manual implementation will be required from "
|
||||
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
|
||||
'`attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
|
||||
# mask needs to be such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
|
||||
|
||||
# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
|
||||
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
|
||||
key_layer, value_layer = past_key_value
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(current_states))
|
||||
value_layer = self.transpose_for_scores(self.value(current_states))
|
||||
if past_key_value is not None and not is_cross_attention:
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
|
||||
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
|
||||
query_layer = query_layer.contiguous()
|
||||
key_layer = key_layer.contiguous()
|
||||
value_layer = value_layer.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
|
||||
# a causal mask in case tgt_len == 1.
|
||||
is_causal = (
|
||||
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
|
||||
)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.dropout_prob if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
|
||||
|
||||
outputs = (attn_output,)
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
return outputs
|
||||
|
||||
|
||||
class XLMRobertaXLSelfOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
@ -287,11 +395,19 @@ class XLMRobertaXLSelfOutput(nn.Module):
        return hidden_states


XLMROBERTAXL_SELF_ATTENTION_CLASSES = {
    "eager": XLMRobertaXLSelfAttention,
    "sdpa": XLMRobertaXLSdpaSelfAttention,
}


class XLMRobertaXLAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.self = XLMRobertaXLSelfAttention(config, position_embedding_type=position_embedding_type)
        self.self = XLMROBERTAXL_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type
        )
        self.output = XLMRobertaXLSelfOutput(config)
        self.pruned_heads = set()
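The `XLMROBERTAXL_SELF_ATTENTION_CLASSES` mapping above is the usual dict-based backend dispatch keyed by `config._attn_implementation`. A toy sketch of the same idea, with placeholder classes and a dummy config rather than the real modules:

```python
# Sketch of dict-based attention-backend dispatch; the classes and config are stand-ins.
class EagerAttention:
    def __init__(self, config, position_embedding_type=None):
        self.kind = "eager"


class SdpaAttention(EagerAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type)
        self.kind = "sdpa"


ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}


class DummyConfig:
    _attn_implementation = "sdpa"  # normally resolved when the model is loaded


config = DummyConfig()
attn = ATTENTION_CLASSES[config._attn_implementation](config, position_embedding_type=None)
print(attn.kind)  # "sdpa"
```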

@ -573,6 +689,7 @@ class XLMRobertaXLPreTrainedModel(PreTrainedModel):
    config_class = XLMRobertaXLConfig
    base_model_prefix = "roberta"
    _no_split_modules = ["XLMRobertaXLEmbeddings", "XLMRobertaXLLayer"]
    _supports_sdpa = True

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
@ -651,18 +768,22 @@ XLM_ROBERTA_XL_INPUTS_DOCSTRING = r"""
    "The bare XLM-RoBERTa-XL Model transformer outputting raw hidden-states without any specific head on top.",
    XLM_ROBERTA_XL_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->XLMRobertaXL, BERT->XLM_ROBERTA_XL
class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in *Attention is
    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
    Kaiser and Illia Polosukhin. To behave as an decoder the model needs to be initialized with the `is_decoder`
    argument of the configuration set to `True`. To be used in a Seq2Seq model, the model needs to initialized with
    both `is_decoder` argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as
    an input to the forward pass. .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument
    and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward
    pass.
    """

    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRobertaXL
    _no_split_modules = ["XLMRobertaXLEmbeddings", "XLMRobertaXLLayer"]

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
@ -672,6 +793,9 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):

        self.pooler = XLMRobertaXLPooler(config) if add_pooling_layer else None

        self.attn_implementation = config._attn_implementation
        self.position_embedding_type = config.position_embedding_type

        # Initialize weights and apply final processing
        self.post_init()

@ -695,7 +819,6 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
@ -759,9 +882,6 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
@ -770,9 +890,43 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        use_sdpa_attention_masks = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        # Expand the attention mask
        if use_sdpa_attention_masks:
            # Expand the attention mask for SDPA.
            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
            if self.config.is_decoder:
                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                    attention_mask,
                    input_shape,
                    embedding_output,
                    past_key_values_length,
                )
            else:
                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
        else:
            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
            # ourselves in which case we just need to make it broadcastable to all heads.
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@ -781,7 +935,15 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

            if use_sdpa_attention_masks:
                # Expand the attention mask for SDPA.
                # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None
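The branch above only changes how the 2D padding mask is expanded: for SDPA it becomes a 4D additive mask of shape `[bsz, 1, tgt_len, src_len]`, with `0.0` at kept positions and a very negative value at padded ones. A rough hand-rolled equivalent of that shape and value convention (a sketch on toy data, not the library helper itself):

```python
import torch

mask_2d = torch.tensor([[1, 1, 1, 0]])  # [bsz, src_len]; 1 = attend, 0 = padding
dtype = torch.float16
tgt_len = 4

# [bsz, src_len] -> [bsz, 1, tgt_len, src_len], additive form expected by SDPA's attn_mask.
expanded = mask_2d[:, None, None, :].expand(mask_2d.shape[0], 1, tgt_len, mask_2d.shape[1]).to(dtype)
additive_mask = (1.0 - expanded) * torch.finfo(dtype).min

print(additive_mask.shape)  # torch.Size([1, 1, 4, 4])
# kept positions -> 0.0, the padded position -> the dtype minimum (a large negative number)
```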

@ -792,13 +954,6 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,

@ -16,7 +16,14 @@
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.testing_utils import (
    require_sentencepiece,
    require_tokenizers,
    require_torch,
    require_torch_sdpa,
    slow,
    torch_device,
)


if is_torch_available():
@ -31,7 +38,7 @@ if is_torch_available():
class CamembertModelIntegrationTest(unittest.TestCase):
    @slow
    def test_output_embeds_base_model(self):
        model = CamembertModel.from_pretrained("almanach/camembert-base")
        model = CamembertModel.from_pretrained("almanach/camembert-base", attn_implementation="eager")
        model.to(torch_device)

        input_ids = torch.tensor(
@ -54,3 +61,24 @@ class CamembertModelIntegrationTest(unittest.TestCase):
        # expected_slice = roberta.model.forward(input_ids)[0][:, :3, :3].detach()

        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

    @slow
    @require_torch_sdpa
    def test_output_embeds_base_model_sdpa(self):
        input_ids = torch.tensor(
            [[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]],
            device=torch_device,
            dtype=torch.long,
        )  # J'aime le camembert !

        expected_slice = torch.tensor(
            [[[-0.0254, 0.0235, 0.1027], [0.0606, -0.1811, -0.0418], [-0.1561, -0.1127, 0.2687]]],
            device=torch_device,
            dtype=torch.float,
        )

        model = CamembertModel.from_pretrained("almanach/camembert-base", attn_implementation="sdpa").to(torch_device)
        with torch.no_grad():
            output = model(input_ids)["last_hidden_state"].detach()

        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
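These integration tests pin down the user-facing behaviour the PR enables: the same checkpoint can be loaded with either attention backend and should give numerically close hidden states. A typical usage snippet, assuming network access to the checkpoint used above:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("almanach/camembert-base")
inputs = tokenizer("J'aime le camembert !", return_tensors="pt")

model_eager = AutoModel.from_pretrained("almanach/camembert-base", attn_implementation="eager")
model_sdpa = AutoModel.from_pretrained("almanach/camembert-base", attn_implementation="sdpa")

with torch.no_grad():
    out_eager = model_eager(**inputs).last_hidden_state
    out_sdpa = model_sdpa(**inputs).last_hidden_state

# The two backends should agree up to small numerical differences.
print(torch.allclose(out_eager, out_sdpa, atol=1e-4))
```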

@ -17,7 +17,13 @@
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow
from transformers.testing_utils import (
    require_sentencepiece,
    require_tokenizers,
    require_torch,
    require_torch_sdpa,
    slow,
)


if is_torch_available():
@ -32,7 +38,7 @@ if is_torch_available():
class XLMRobertaModelIntegrationTest(unittest.TestCase):
    @slow
    def test_xlm_roberta_base(self):
        model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base")
        model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base", attn_implementation="eager")
        input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]])
        # The dog is cute and lives in the garden house

@ -49,6 +55,23 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase):
        # compare the actual values for a slice of last dim
        self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))

    @require_torch_sdpa
    def test_xlm_roberta_base_sdpa(self):
        input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]])
        # The dog is cute and lives in the garden house

        expected_output_shape = torch.Size((1, 12, 768))  # batch_size, sequence_length, embedding_vector_dim
        expected_output_values_last_dim = torch.tensor(
            [[-0.0101, 0.1218, -0.0803, 0.0801, 0.1327, 0.0776, -0.1215, 0.2383, 0.3338, 0.3106, 0.0300, 0.0252]]
        )

        model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base", attn_implementation="sdpa")
        with torch.no_grad():
            output = model(input_ids)["last_hidden_state"].detach()
        self.assertEqual(output.shape, expected_output_shape)
        # compare the actual values for a slice of last dim
        self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))

    @slow
    def test_xlm_roberta_large(self):
        model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-large")

@ -14,10 +14,11 @@
# limitations under the License.


import tempfile
import unittest

from transformers import XLMRobertaXLConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import require_torch, require_torch_sdpa, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@ -515,6 +516,84 @@ class XLMRobertaXLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
        self.assertEqual(position_ids.shape, expected_positions.shape)
        self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))

    # TODO: Remove this and use the parent method (in common tests) once XLM RoBERTa XL supports low_cpu_mem_usage=True.
    @require_torch_sdpa
    @slow
    # Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_generate
    def test_eager_matches_sdpa_generate(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        max_new_tokens = 30

        if len(self.all_generative_model_classes) == 0:
            self.skipTest(f"{self.__class__.__name__} tests a model that does not support generate: skipping this test")

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_sdpa:
                self.skipTest(f"{model_class.__name__} does not support SDPA")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))

                # Ignore copy
                model_sdpa = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=False,
                ).to(torch_device)

                self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

                # Ignore copy
                model_eager = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=False,
                    attn_implementation="eager",
                ).to(torch_device)

                self.assertTrue(model_eager.config._attn_implementation == "eager")

                for name, submodule in model_eager.named_modules():
                    class_name = submodule.__class__.__name__
                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                        raise ValueError("The eager model should not have SDPA attention layers")

                has_sdpa = False
                for name, submodule in model_sdpa.named_modules():
                    class_name = submodule.__class__.__name__
                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                        has_sdpa = True
                        break
                if not has_sdpa:
                    raise ValueError("The SDPA model should have SDPA attention layers")

                # Just test that a large cache works as expected
                res_eager = model_eager.generate(
                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
                )

                res_sdpa = model_sdpa.generate(
                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
                )

                self.assertTrue(torch.allclose(res_eager, res_sdpa))
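The copied generate test verifies which backend was actually instantiated by walking `named_modules()` and looking for `*SdpaAttention` / `*SdpaSelfAttention` class names before comparing greedy generations. That module walk is a handy standalone debugging check; a small sketch, assuming the checkpoint below and that the SDPA classes follow this naming pattern:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("FacebookAI/xlm-roberta-base", attn_implementation="sdpa")

sdpa_layers = [
    name
    for name, module in model.named_modules()
    if "SdpaSelfAttention" in module.__class__.__name__ or "SdpaAttention" in module.__class__.__name__
]
print(model.config._attn_implementation)  # "sdpa"
print(len(sdpa_layers))  # one entry per encoder layer if the SDPA backend is active
```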


@require_torch
class XLMRobertaModelXLIntegrationTest(unittest.TestCase):

@ -70,6 +70,7 @@ def check_sdpa_support_list():
        "For now, Transformers supports SDPA inference and training for the following architectures:"
    )[1]
    doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]
    doctext = doctext.lower()

    patterns = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py"))
    patterns_tf = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py"))
@ -85,7 +86,7 @@ def check_sdpa_support_list():
            archs_supporting_sdpa.append(model_name)

    for arch in archs_supporting_sdpa:
        if arch not in doctext and arch not in doctext.replace("-", "_"):
        if not any(term in doctext for term in [arch, arch.replace("_", "-"), arch.replace("_", " ")]):
            raise ValueError(
                f"{arch} should be listed in the SDPA documentation but is not. Please update the documentation."
            )
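The tightened check above normalises architecture names before matching, so a model folder such as `xlm_roberta_xl` is found whether the documentation writes it with underscores, hyphens, or spaces. A quick illustration of that matching rule on made-up strings:

```python
# Made-up documentation text; only the matching rule is being demonstrated.
doctext = (
    "* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl)\n"
    "* [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert)\n"
).lower()

for arch in ["xlm_roberta_xl", "camembert", "some_missing_model"]:
    found = any(term in doctext for term in [arch, arch.replace("_", "-"), arch.replace("_", " ")])
    print(arch, found)
# xlm_roberta_xl True   (matches the hyphenated form in the URL)
# camembert True
# some_missing_model False
```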