mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Remove head mask in generative models (#35786)
* just squash into one commit * delete print
This commit is contained in:
parent
0173a99e73
commit
955e61b0da
@ -57,6 +57,7 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). This
|
||||
- Embedding size E is different from hidden size H justified because the embeddings are context independent (one embedding vector represents one token), whereas hidden states are context dependent (one hidden state represents a sequence of tokens) so it's more logical to have H >> E. Also, the embedding matrix is large since it's V x E (V being the vocab size). If E < H, it has less parameters.
|
||||
- Layers are split in groups that share parameters (to save memory).
|
||||
Next sentence prediction is replaced by a sentence ordering prediction: in the inputs, we have two sentences A and B (that are consecutive) and we either feed A followed by B or B followed by A. The model must predict if they have been swapped or not.
|
||||
- The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
|
@ -55,6 +55,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The
|
||||
* mask a span of k tokens with a single mask token (a span of 0 tokens is an insertion of a mask token)
|
||||
* permute sentences
|
||||
* rotate the document to make it start at a specific token
|
||||
- The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Implementation Notes
|
||||
|
||||
|
@ -36,6 +36,7 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The
|
||||
- BioGPT is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left.
|
||||
- BioGPT was trained with a causal language modeling (CLM) objective and is therefore powerful at predicting the next token in a sequence. Leveraging this feature allows BioGPT to generate syntactically coherent text as it can be observed in the run_generation.py example script.
|
||||
- The model can take the `past_key_values` (for PyTorch) as input, which is the previously computed key/value attention pairs. Using this (past_key_values or past) value prevents the model from re-computing pre-computed values in the context of text generation. For PyTorch, see past_key_values argument of the BioGptForCausalLM.forward() method for more information on its usage.
|
||||
- The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
|
@ -53,6 +53,7 @@ The original code for vision can be found [here](https://github.com/facebookrese
|
||||
- For Data2VecAudio, preprocessing is identical to [`Wav2Vec2Model`], including feature extraction
|
||||
- For Data2VecText, preprocessing is identical to [`RobertaModel`], including tokenization.
|
||||
- For Data2VecVision, preprocessing is identical to [`BeitModel`], including feature extraction.
|
||||
- The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
### Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
|
@ -46,8 +46,12 @@ The main differences compared to GPT2.
|
||||
- Merge the key and value caches into one (this changes the format of layer_past/ present, does it risk creating problems?)
|
||||
- Use the memory layout (self.num_heads, 3, self.head_dim) instead of `(3, self.num_heads, self.head_dim)` for the QKV tensor with MHA. (prevents an overhead with the merged key and values, but makes the checkpoints incompatible with the original openai-community/gpt2 model).
|
||||
|
||||
|
||||
You can read more about the optimizations in the [original pull request](https://github.com/huggingface/transformers/pull/22575)
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Combining Starcoder and Flash Attention 2
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
|
||||
|
@ -50,7 +50,7 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
|
||||
- Hubert is a speech model that accepts a float array corresponding to the raw waveform of the speech signal.
|
||||
- Hubert model was fine-tuned using connectionist temporal classification (CTC) so the model output has to be decoded
|
||||
using [`Wav2Vec2CTCTokenizer`].
|
||||
|
||||
- The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Using Flash Attention 2
|
||||
|
||||
|
@ -51,6 +51,9 @@ multilingual it expects the sequences in a certain format: A special language id
|
||||
source and target text. The source text format is `[lang_code] X [eos]`, where `lang_code` is source language
|
||||
id for source text and target language id for target text, with `X` being the source or target text.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
The [`M2M100Tokenizer`] depends on `sentencepiece` so be sure to install it before running the
|
||||
examples. To install `sentencepiece` run `pip install sentencepiece`.
|
||||
|
||||
|
@ -35,6 +35,9 @@ You can find all the original mBART checkpoints under the [AI at Meta](https://h
|
||||
> [!TIP]
|
||||
> Click on the mBART models in the right sidebar for more examples of applying mBART to different language tasks.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
The example below demonstrates how to translate text with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
<hfoptions id="usage">
|
||||
|
@ -62,6 +62,9 @@ python src/transformers/models/musicgen/convert_musicgen_transformers.py \
|
||||
--checkpoint small --pytorch_dump_folder /output/path --safe_serialization
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Generation
|
||||
|
||||
MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly
|
||||
|
@ -44,6 +44,9 @@ There are two key differences with MusicGen:
|
||||
1. The audio prompt is used here as a conditional signal for the generated audio sample, whereas it's used for audio continuation in [MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen).
|
||||
2. Conditional text and audio signals are concatenated to the decoder's hidden states instead of being used as a cross-attention signal, as in MusicGen.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Generation
|
||||
|
||||
MusicGen Melody is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly better results than greedy, thus we encourage sampling mode to be used where possible. Sampling is enabled by default, and can be explicitly specified by setting `do_sample=True` in the call to [`MusicgenMelodyForConditionalGeneration.generate`], or by overriding the model's generation config (see below).
|
||||
|
@ -41,6 +41,9 @@ Tips:
|
||||
- OPT has the same architecture as [`BartDecoder`].
|
||||
- Contrary to GPT2, OPT adds the EOS token `</s>` to the beginning of every prompt.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with OPT. If you're
|
||||
|
@ -40,6 +40,9 @@ The abstract from the paper is the following:
|
||||
|
||||
`Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
### Inference
|
||||
|
||||
```python
|
||||
|
@ -46,6 +46,9 @@ This model was contributed by [anton-l](https://huggingface.co/anton-l).
|
||||
- SEWForCTC is fine-tuned using connectionist temporal classification (CTC) so the model output has to be decoded using
|
||||
[`Wav2Vec2CTCTokenizer`].
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Resources
|
||||
|
||||
- [Audio classification task guide](../tasks/audio_classification)
|
||||
|
@ -54,6 +54,9 @@ found [here](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT).
|
||||
decoded using [`Wav2Vec2CTCTokenizer`].
|
||||
- UniSpeechSat performs especially well on speaker verification, speaker identification, and speaker diarization tasks.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Resources
|
||||
|
||||
- [Audio classification task guide](../tasks/audio_classification)
|
||||
|
@ -49,6 +49,9 @@ found [here](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech).
|
||||
- UniSpeech model can be fine-tuned using connectionist temporal classification (CTC) so the model output has to be
|
||||
decoded using [`Wav2Vec2CTCTokenizer`].
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Resources
|
||||
|
||||
- [Audio classification task guide](../tasks/audio_classification)
|
||||
|
@ -50,6 +50,9 @@ Note: Meta (FAIR) released a new version of [Wav2Vec2-BERT 2.0](https://huggingf
|
||||
- Wav2Vec2 model was trained using connectionist temporal classification (CTC) so the model output has to be decoded
|
||||
using [`Wav2Vec2CTCTokenizer`].
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
## Using Flash Attention 2
|
||||
|
||||
Flash Attention 2 is an faster, optimized version of the model.
|
||||
|
@ -32,6 +32,9 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
You can find all the original Whisper checkpoints under the [Whisper](https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013) collection.
|
||||
|
||||
> [!NOTE]
|
||||
> The `head_mask` argument is ignored when using all attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Whisper models in the right sidebar for more examples of how to apply Whisper to different audio tasks.
|
||||
|
||||
|
@ -367,15 +367,15 @@ class AlbertSdpaAttention(AlbertAttention):
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
|
||||
if self.position_embedding_type != "absolute" or output_attentions:
|
||||
logger.warning(
|
||||
"AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||
"non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
|
||||
"non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
|
||||
"the eager attention implementation, but specifying the eager implementation will be required from "
|
||||
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
|
||||
'`attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(hidden_states, attention_mask, head_mask, output_attentions)
|
||||
return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
|
||||
|
||||
batch_size, seq_len, _ = hidden_states.size()
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
|
@ -290,10 +290,6 @@ class BartFlashAttention2(BartAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# BartFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("BartFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -400,10 +396,10 @@ class BartSdpaAttention(BartAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -411,7 +407,6 @@ class BartSdpaAttention(BartAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -253,10 +253,10 @@ class BioGptSdpaAttention(BioGptAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"BioGptModel is using BioGptSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"BioGptModel is using BioGptSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -264,7 +264,6 @@ class BioGptSdpaAttention(BioGptAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -352,10 +352,6 @@ class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# Data2VecAudioFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("Data2VecAudioFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -462,10 +458,10 @@ class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Data2VecAudioModel is using Data2VecAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"Data2VecAudioModel is using Data2VecAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -473,7 +469,6 @@ class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -391,13 +391,7 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
|
||||
|
||||
|
||||
class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
|
||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||
if head_mask is not None:
|
||||
# The super dispatch is done in the forward.
|
||||
raise ValueError(
|
||||
"PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository."
|
||||
)
|
||||
|
||||
def _attn(self, query, key, value, attention_mask=None):
|
||||
scale = None
|
||||
if not self.scale_attn_weights:
|
||||
scale = 1
|
||||
@ -507,17 +501,17 @@ class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
|
||||
|
||||
key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
|
||||
|
||||
if not output_attentions and head_mask is None:
|
||||
if not output_attentions:
|
||||
# Difference with the original implementation: there is no need to transpose the key here,
|
||||
# as SDPA expects seq_length to be at index -2 for the key as well
|
||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask)
|
||||
else:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None."
|
||||
"GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`."
|
||||
' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
|
||||
attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask)
|
||||
|
||||
if not self.multi_query:
|
||||
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
|
||||
|
@ -409,10 +409,6 @@ class HubertFlashAttention2(HubertAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# HubertFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("HubertFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -519,10 +515,10 @@ class HubertSdpaAttention(HubertAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"HubertModel is using HubertSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"HubertModel is using HubertSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -530,7 +526,6 @@ class HubertSdpaAttention(HubertAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -360,10 +360,6 @@ class M2M100FlashAttention2(M2M100Attention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# M2M100FlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("M2M100FlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -471,10 +467,10 @@ class M2M100SdpaAttention(M2M100Attention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"M2M100Model is using M2M100SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"M2M100Model is using M2M100SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -482,7 +478,6 @@ class M2M100SdpaAttention(M2M100Attention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
@ -1270,11 +1265,6 @@ class M2M100Model(M2M100PreTrainedModel):
|
||||
self.encoder = M2M100Encoder(config, self.shared)
|
||||
self.decoder = M2M100Decoder(config, self.shared)
|
||||
|
||||
if config._attn_implementation == "flash_attention_2":
|
||||
logger.warning_once(
|
||||
"Attention with Flash Attention 2 does not support `layer_head_mask`. If you need this feature, please use standard attention."
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
|
@ -297,10 +297,6 @@ class MBartFlashAttention2(MBartAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# MBartFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("MBartFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -408,10 +404,10 @@ class MBartSdpaAttention(MBartAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"MBartModel is using MBartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"MBartModel is using MBartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -419,7 +415,6 @@ class MBartSdpaAttention(MBartAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -333,10 +333,6 @@ class MusicgenFlashAttention2(MusicgenAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# MusicgenFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -443,10 +439,11 @@ class MusicgenSdpaAttention(MusicgenAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -454,7 +451,6 @@ class MusicgenSdpaAttention(MusicgenAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
@ -471,7 +467,6 @@ class MusicgenSdpaAttention(MusicgenAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -346,10 +346,6 @@ class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# MusicgenMelodyFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("MusicgenMelodyFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -457,10 +453,10 @@ class MusicgenMelodySdpaAttention(MusicgenMelodyAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"MusicgenMelodyModel is using MusicgenMelodySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"MusicgenMelodyModel is using MusicgenMelodySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -468,7 +464,6 @@ class MusicgenMelodySdpaAttention(MusicgenMelodyAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -298,16 +298,15 @@ class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Qwen2AudioModel is using Qwen2AudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"Qwen2AudioModel is using Qwen2AudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -559,10 +559,6 @@ class SEWFlashAttention2(SEWAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# SEWFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("SEWFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -670,10 +666,10 @@ class SEWSdpaAttention(SEWAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"SEWModel is using SEWSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"SEWModel is using SEWSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -681,7 +677,6 @@ class SEWSdpaAttention(SEWAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -448,10 +448,6 @@ class UniSpeechFlashAttention2(UniSpeechAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# UniSpeechFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("UniSpeechFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -558,10 +554,10 @@ class UniSpeechSdpaAttention(UniSpeechAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"UniSpeechModel is using UniSpeechSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"UniSpeechModel is using UniSpeechSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -569,7 +565,6 @@ class UniSpeechSdpaAttention(UniSpeechAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -451,10 +451,6 @@ class UniSpeechSatFlashAttention2(UniSpeechSatAttention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# UniSpeechSatFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("UniSpeechSatFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -561,10 +557,10 @@ class UniSpeechSatSdpaAttention(UniSpeechSatAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"UniSpeechSatModel is using UniSpeechSatSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"UniSpeechSatModel is using UniSpeechSatSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -572,7 +568,6 @@ class UniSpeechSatSdpaAttention(UniSpeechSatAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -652,10 +652,6 @@ class Wav2Vec2FlashAttention2(Wav2Vec2Attention):
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# Wav2Vec2FlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("Wav2Vec2FlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
@ -763,10 +759,10 @@ class Wav2Vec2SdpaAttention(Wav2Vec2Attention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Wav2Vec2Model is using Wav2Vec2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"Wav2Vec2Model is using Wav2Vec2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -774,7 +770,6 @@ class Wav2Vec2SdpaAttention(Wav2Vec2Attention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
|
@ -373,9 +373,6 @@ class WhisperFlashAttention2(WhisperAttention):
|
||||
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
|
||||
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
|
||||
)
|
||||
# WhisperFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("WhisperFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
@ -477,10 +474,11 @@ class WhisperSdpaAttention(WhisperAttention):
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions or layer_head_mask is not None:
|
||||
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
|
||||
"WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
@ -488,7 +486,6 @@ class WhisperSdpaAttention(WhisperAttention):
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
@ -1414,6 +1414,7 @@ class GenerationTesterMixin:
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
config._attn_implementation = "eager" # head mask works only in eager mode and will be removed soon
|
||||
text_config = config.get_text_config()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
@ -58,28 +58,16 @@ def prepare_bart_inputs_dict(
|
||||
decoder_input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@ -167,10 +155,9 @@ class BartModelTester:
|
||||
model = BartModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
|
@ -119,11 +119,10 @@ class TFBartModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@ -158,9 +157,6 @@ def prepare_bart_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@ -172,20 +168,11 @@ def prepare_bart_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
@ -135,9 +135,7 @@ class BioGptModelTester:
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_biogpt_model_attention_mask_past(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
):
|
||||
def create_and_check_biogpt_model_attention_mask_past(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = BioGptModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@ -177,9 +175,7 @@ class BioGptModelTester:
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_biogpt_model_past_large_inputs(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
):
|
||||
def create_and_check_biogpt_model_past_large_inputs(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = BioGptModel(config=config).to(torch_device).eval()
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
@ -213,7 +209,7 @@ class BioGptModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_forward_and_backwards(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
):
|
||||
model = BioGptForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
@ -233,9 +229,7 @@ class BioGptModelTester:
|
||||
self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
|
||||
self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
|
||||
|
||||
def create_and_check_biogpt_for_token_classification(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
):
|
||||
def create_and_check_biogpt_for_token_classification(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
config.num_labels = self.num_labels
|
||||
model = BioGptForTokenClassification(config)
|
||||
model.to(torch_device)
|
||||
|
@ -128,13 +128,10 @@ class GPTBigCodeModelTester:
|
||||
reorder_and_upcast_attn=reorder_and_upcast_attn,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
@ -174,19 +171,19 @@ class GPTBigCodeModelTester:
|
||||
config.vocab_size = 300
|
||||
return config
|
||||
|
||||
def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(len(result.past_key_values), config.n_layer)
|
||||
|
||||
def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@ -223,7 +220,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_gpt_bigcode_model_attention_mask_past(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
self, config, input_ids, input_mask, token_type_ids, *args
|
||||
):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
@ -265,7 +262,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_gpt_bigcode_model_past_large_inputs(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
|
||||
self, config, input_ids, input_mask, token_type_ids, *args
|
||||
):
|
||||
model = GPTBigCodeModel(config=config)
|
||||
model.to(torch_device)
|
||||
@ -302,7 +299,7 @@ class GPTBigCodeModelTester:
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
def create_and_check_lm_head_model(self, config, input_ids, input_mask, token_type_ids, *args):
|
||||
model = GPTBigCodeForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@ -312,7 +309,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_forward_and_backwards(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
):
|
||||
model = GPTBigCodeForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
@ -325,7 +322,7 @@ class GPTBigCodeModelTester:
|
||||
result.loss.backward()
|
||||
|
||||
def create_and_check_gpt_bigcode_for_sequence_classification(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
self, config, input_ids, input_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = GPTBigCodeForSequenceClassification(config)
|
||||
@ -335,7 +332,7 @@ class GPTBigCodeModelTester:
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_gpt_bigcode_for_token_classification(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
self, config, input_ids, input_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = GPTBigCodeForTokenClassification(config)
|
||||
@ -359,7 +356,6 @@ class GPTBigCodeModelTester:
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
@ -370,7 +366,6 @@ class GPTBigCodeModelTester:
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
return config, inputs_dict
|
||||
|
@ -51,28 +51,16 @@ def prepare_m2m_100_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@ -166,10 +154,9 @@ class M2M100ModelTester:
|
||||
model = M2M100Model(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
|
@ -55,28 +55,16 @@ def prepare_mbart_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@ -158,10 +146,9 @@ class MBartModelTester:
|
||||
model = MBartModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
|
@ -107,11 +107,10 @@ class TFMBartModelTester:
|
||||
|
||||
input_ids = input_ids[:1, :]
|
||||
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
self.batch_size = 1
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
past_key_values = past_key_values[1]
|
||||
@ -123,9 +122,6 @@ def prepare_mbart_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||
@ -137,20 +133,11 @@ def prepare_mbart_inputs_dict(
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
@ -76,27 +76,19 @@ def prepare_musicgen_decoder_inputs_dict(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :]
|
||||
attention_mask = attention_mask.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
|
||||
if encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@ -467,9 +459,6 @@ def prepare_musicgen_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
labels=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
@ -477,26 +466,11 @@ def prepare_musicgen_inputs_dict(
|
||||
-1, config.decoder.num_codebooks, decoder_input_ids.shape[-1]
|
||||
)[:, 0, :]
|
||||
decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(
|
||||
config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device
|
||||
)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(
|
||||
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
|
||||
)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(
|
||||
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
|
||||
)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
@ -80,15 +80,12 @@ def prepare_musicgen_melody_decoder_inputs_dict(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.reshape(-1, config.num_codebooks, input_ids.shape[-1])[:, 0, :]
|
||||
attention_mask = attention_mask.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads, device=torch_device)
|
||||
if encoder_attention_mask is None and encoder_hidden_states is not None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_states.shape[:2], device=torch_device)
|
||||
return {
|
||||
@ -96,7 +93,6 @@ def prepare_musicgen_melody_decoder_inputs_dict(
|
||||
"attention_mask": attention_mask,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
|
||||
@ -475,8 +471,6 @@ def prepare_musicgen_melody_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
labels=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
@ -484,21 +478,11 @@ def prepare_musicgen_melody_inputs_dict(
|
||||
-1, config.decoder.num_codebooks, decoder_input_ids.shape[-1]
|
||||
)[:, 0, :]
|
||||
decoder_attention_mask = decoder_attention_mask.ne(config.decoder.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(
|
||||
config.text_encoder.num_hidden_layers, config.text_encoder.num_attention_heads, device=torch_device
|
||||
)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(
|
||||
config.decoder.num_hidden_layers, config.decoder.num_attention_heads, device=torch_device
|
||||
)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
@ -52,15 +52,12 @@ def prepare_opt_inputs_dict(
|
||||
decoder_input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
|
||||
@ -156,10 +153,9 @@ class OPTModelTester:
|
||||
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@ -187,7 +183,7 @@ class OPTModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
# test no attention_mask works
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
_, past_key_values = outputs.to_tuple()
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
|
||||
|
@ -62,25 +62,13 @@ def prepare_whisper_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = tf.where(decoder_input_ids != config.pad_token_id, 1, 0)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_features": input_features,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@ -350,9 +338,6 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("fp16 is not yet supported for TF models")
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
@ -159,26 +159,14 @@ def prepare_whisper_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
# "input_ids": input_features,
|
||||
"input_features": input_features,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@ -3235,12 +3223,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue((eager_generated_ids[permutation_idx, :] == static_generated_ids).all())
|
||||
|
||||
|
||||
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
return {"input_features": input_features, "head_mask": head_mask}
|
||||
|
||||
|
||||
@require_torch
|
||||
class WhisperEncoderModelTester:
|
||||
def __init__(
|
||||
@ -3314,10 +3296,7 @@ class WhisperEncoderModelTester:
|
||||
input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length])
|
||||
|
||||
config = self.get_config()
|
||||
inputs_dict = prepare_whisper_encoder_inputs_dict(
|
||||
config,
|
||||
input_features=input_features,
|
||||
)
|
||||
inputs_dict = {"input_features": input_features}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@ -3427,8 +3406,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
encoder_inputs = {"input_features": inputs["input_features"]}
|
||||
del inputs["input_features"]
|
||||
|
||||
if "head_mask" in inputs:
|
||||
encoder_inputs["head_mask"] = inputs["head_mask"]
|
||||
if "attention_mask" in inputs:
|
||||
encoder_inputs["attention_mask"] = inputs["attention_mask"]
|
||||
if "output_attentions" in inputs:
|
||||
@ -3523,9 +3500,6 @@ class WhisperStandaloneDecoderModelTester:
|
||||
)
|
||||
|
||||
inputs_dict.pop("input_features")
|
||||
inputs_dict.pop("head_mask")
|
||||
inputs_dict.pop("decoder_head_mask")
|
||||
inputs_dict.pop("cross_attn_head_mask")
|
||||
|
||||
inputs_dict["attention_mask"] = inputs_dict.pop("decoder_attention_mask")
|
||||
inputs_dict["input_ids"] = inputs_dict.pop("decoder_input_ids")
|
||||
|
@ -1444,6 +1444,7 @@ class ModelTesterMixin:
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = True
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init._attn_implementation = "eager" # head mask works only in eager mode and will be removed soon
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
|
Loading…
Reference in New Issue
Block a user