mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Fix cross-attention head mask for Torch encoder-decoder models (#10605)
* Fix cross-attention head mask for Torch BART models * Fix head masking for cross-attention module for the following models: BART, Blenderbot, Blenderbot_small, M2M_100, Marian, MBart, Pegasus * Enable test_headmasking for M2M_100 model * Fix cross_head_mask for FSMT, LED and T5 * This commit fixes `head_mask` for cross-attention modules in the following models: FSMT, LED, T5 * It also contains some smaller changes in doc so that it is be perfectly clear the shape of `cross_head_mask` is the same as of `decoder_head_mask` * Update template * Fix template for BartForCausalLM * Fix cross_head_mask for Speech2Text models * Fix cross_head_mask in templates * Fix args order in BartForCausalLM template * Fix doc in BART templates * Make more explicit naming * `cross_head_mask` -> `cross_attn_head_mask` * `cross_layer_head_mask` -> `cross_attn_layer_head_mask` * Fix doc * make style quality * Fix speech2text docstring
This commit is contained in:
parent
ca6b80cadb
commit
e3ff165aa5
@ -296,7 +296,7 @@ class BartEncoderLayer(nn.Module):
|
||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
@ -368,7 +368,7 @@ class BartDecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@ -382,9 +382,9 @@ class BartDecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@ -419,7 +419,7 @@ class BartDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -598,18 +598,25 @@ BART_INPUTS_DOCSTRING = r"""
|
||||
If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and
|
||||
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
||||
information on the default strategy.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@ -710,11 +717,11 @@ class BartEncoder(BartPretrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
@ -875,7 +882,7 @@ class BartDecoder(BartPretrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@ -912,18 +919,18 @@ class BartDecoder(BartPretrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
||||
cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -993,11 +1000,12 @@ class BartDecoder(BartPretrainedModel):
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -1031,7 +1039,7 @@ class BartDecoder(BartPretrainedModel):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@ -1042,7 +1050,9 @@ class BartDecoder(BartPretrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -1123,6 +1133,7 @@ class BartModel(BartPretrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1172,7 +1183,7 @@ class BartModel(BartPretrainedModel):
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@ -1248,6 +1259,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1282,6 +1294,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1386,6 +1399,7 @@ class BartForSequenceClassification(BartPretrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
@ -1416,6 +1430,7 @@ class BartForSequenceClassification(BartPretrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1496,6 +1511,7 @@ class BartForQuestionAnswering(BartPretrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
@ -1527,6 +1543,7 @@ class BartForQuestionAnswering(BartPretrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1633,7 +1650,7 @@ class BartForCausalLM(BartPretrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
@ -1666,18 +1683,17 @@ class BartForCausalLM(BartPretrainedModel):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -1734,7 +1750,7 @@ class BartForCausalLM(BartPretrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
@ -298,7 +298,7 @@ class BlenderbotEncoderLayer(nn.Module):
|
||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
@ -371,7 +371,7 @@ class BlenderbotDecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@ -385,9 +385,9 @@ class BlenderbotDecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@ -423,7 +423,7 @@ class BlenderbotDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -554,18 +554,25 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
|
||||
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@ -666,11 +673,11 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
@ -834,7 +841,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@ -871,18 +878,19 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
||||
cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -951,11 +959,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -989,7 +998,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@ -1000,7 +1009,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -1090,6 +1101,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1147,7 +1159,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@ -1241,6 +1253,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1275,6 +1288,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1395,7 +1409,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
@ -1428,18 +1442,17 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -1496,7 +1509,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
@ -296,7 +296,7 @@ class BlenderbotSmallEncoderLayer(nn.Module):
|
||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
@ -369,7 +369,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@ -383,9 +383,9 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@ -420,7 +420,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -555,18 +555,25 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
|
||||
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@ -667,11 +674,11 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
@ -834,7 +841,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@ -871,18 +878,18 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
||||
cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -953,10 +960,12 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -990,7 +999,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@ -1001,7 +1010,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -1077,6 +1088,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1134,7 +1146,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@ -1216,6 +1228,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1250,6 +1263,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1370,7 +1384,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
@ -1403,18 +1417,17 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -1471,7 +1484,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
@ -248,17 +248,25 @@ FSMT_INPUTS_DOCSTRING = r"""
|
||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
|
||||
@ -573,7 +581,7 @@ class DecoderLayer(nn.Module):
|
||||
layer_state=None,
|
||||
causal_mask=None,
|
||||
layer_head_mask=None,
|
||||
encoder_layer_head_mask=None,
|
||||
cross_attn_layer_head_mask=None,
|
||||
decoder_padding_mask=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
@ -604,7 +612,7 @@ class DecoderLayer(nn.Module):
|
||||
key=encoder_hidden_states,
|
||||
key_padding_mask=encoder_attn_mask,
|
||||
layer_state=layer_state, # mutates layer state
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
@ -666,7 +674,7 @@ class FSMTDecoder(nn.Module):
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
@ -690,12 +698,11 @@ class FSMTDecoder(nn.Module):
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
Returns:
|
||||
BaseModelOutputWithPast or tuple:
|
||||
@ -732,10 +739,11 @@ class FSMTDecoder(nn.Module):
|
||||
next_decoder_cache = []
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -756,7 +764,7 @@ class FSMTDecoder(nn.Module):
|
||||
layer_state=layer_state,
|
||||
causal_mask=decoder_causal_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
@ -1009,6 +1017,7 @@ class FSMTModel(PretrainedFSMTModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs: Optional[Tuple] = None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
@ -1065,7 +1074,7 @@ class FSMTModel(PretrainedFSMTModel):
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask=causal_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
@ -1143,6 +1152,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
labels=None,
|
||||
@ -1173,6 +1183,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
|
@ -901,7 +901,7 @@ class LEDEncoderLayer(nn.Module):
|
||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
"""
|
||||
residual = hidden_states
|
||||
attn_outputs = self.self_attn(
|
||||
@ -968,7 +968,7 @@ class LEDDecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@ -982,9 +982,9 @@ class LEDDecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
`(decoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`): Whether the base model outputs attentions.
|
||||
This requires the attentions tensor to be reshaped in this function.
|
||||
@ -1018,7 +1018,7 @@ class LEDDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -1199,17 +1199,6 @@ class LEDSeq2SeqModelOutput(ModelOutput):
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
last_hidden_state: torch.FloatTensor = None
|
||||
@ -1221,8 +1210,6 @@ class LEDSeq2SeqModelOutput(ModelOutput):
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
head_mask: Optional[torch.FloatTensor] = None
|
||||
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1278,17 +1265,6 @@ class LEDSeq2SeqLMOutput(ModelOutput):
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
@ -1301,8 +1277,6 @@ class LEDSeq2SeqLMOutput(ModelOutput):
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
head_mask: Optional[torch.FloatTensor] = None
|
||||
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1358,17 +1332,6 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
@ -1381,8 +1344,6 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
head_mask: Optional[torch.FloatTensor] = None
|
||||
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1440,17 +1401,6 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
Global attentions weights after the attention softmax, used to compute the weighted average in the
|
||||
self-attention heads. Those are the attention weights from every token with global attention to every token
|
||||
in the sequence.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
@ -1464,8 +1414,6 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
head_mask: Optional[torch.FloatTensor] = None
|
||||
decoder_head_mask: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
LED_START_DOCSTRING = r"""
|
||||
@ -1547,17 +1495,24 @@ LED_INPUTS_DOCSTRING = r"""
|
||||
|
||||
- 0 for local attention (a sliding window attention),
|
||||
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@ -1730,7 +1685,7 @@ class LEDEncoder(LEDPreTrainedModel):
|
||||
|
||||
- 0 for local attention (a sliding window attention),
|
||||
- 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
@ -1914,7 +1869,7 @@ class LEDDecoder(LEDPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@ -1961,18 +1916,17 @@ class LEDDecoder(LEDPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -2052,11 +2006,12 @@ class LEDDecoder(LEDPreTrainedModel):
|
||||
all_cross_attentions = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -2090,7 +2045,7 @@ class LEDDecoder(LEDPreTrainedModel):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@ -2100,7 +2055,9 @@ class LEDDecoder(LEDPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -2180,6 +2137,7 @@ class LEDModel(LEDPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
global_attention_mask=None,
|
||||
past_key_values=None,
|
||||
@ -2224,7 +2182,7 @@ class LEDModel(LEDPreTrainedModel):
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@ -2306,6 +2264,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
global_attention_mask=None,
|
||||
past_key_values=None,
|
||||
@ -2358,6 +2317,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -2463,6 +2423,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
global_attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
@ -2495,6 +2456,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -2571,6 +2533,7 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
global_attention_mask=None,
|
||||
start_positions=None,
|
||||
@ -2604,6 +2567,7 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
|
||||
global_attention_mask=global_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
|
@ -367,7 +367,7 @@ class M2M100EncoderLayer(nn.Module):
|
||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
@ -440,7 +440,7 @@ class M2M100DecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@ -454,9 +454,9 @@ class M2M100DecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@ -492,7 +492,7 @@ class M2M100DecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -603,6 +603,24 @@ M2M_100_INPUTS_DOCSTRING = r"""
|
||||
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@ -704,6 +722,12 @@ class M2M100Encoder(M2M100PreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
@ -841,7 +865,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@ -878,6 +902,19 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
||||
cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
@ -955,11 +992,12 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
||||
all_cross_attentions = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -993,7 +1031,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@ -1004,7 +1042,9 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -1085,6 +1125,7 @@ class M2M100Model(M2M100PreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1126,7 +1167,7 @@ class M2M100Model(M2M100PreTrainedModel):
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@ -1201,6 +1242,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1249,6 +1291,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1281,7 +1324,14 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
||||
self,
|
||||
decoder_input_ids,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
encoder_outputs=None,
|
||||
**kwargs,
|
||||
):
|
||||
# cut decoder_input_ids if past is used
|
||||
if past is not None:
|
||||
@ -1293,6 +1343,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
|
@ -313,7 +313,7 @@ class MarianEncoderLayer(nn.Module):
|
||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
@ -386,7 +386,7 @@ class MarianDecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@ -400,9 +400,9 @@ class MarianDecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@ -437,7 +437,7 @@ class MarianDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -567,18 +567,25 @@ MARIAN_INPUTS_DOCSTRING = r"""
|
||||
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@ -678,11 +685,11 @@ class MarianEncoder(MarianPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
@ -842,7 +849,7 @@ class MarianDecoder(MarianPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@ -879,18 +886,18 @@ class MarianDecoder(MarianPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
||||
cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -959,11 +966,12 @@ class MarianDecoder(MarianPreTrainedModel):
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -997,7 +1005,7 @@ class MarianDecoder(MarianPreTrainedModel):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@ -1008,7 +1016,9 @@ class MarianDecoder(MarianPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -1084,6 +1094,7 @@ class MarianModel(MarianPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1142,7 +1153,7 @@ class MarianModel(MarianPreTrainedModel):
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@ -1229,6 +1240,7 @@ class MarianMTModel(MarianPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1264,6 +1276,7 @@ class MarianMTModel(MarianPreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1391,7 +1404,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
@ -1424,18 +1437,17 @@ class MarianForCausalLM(MarianPreTrainedModel):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -1492,7 +1504,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
@ -303,7 +303,7 @@ class MBartEncoderLayer(nn.Module):
|
||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
@ -375,7 +375,7 @@ class MBartDecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@ -389,9 +389,9 @@ class MBartDecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@ -427,7 +427,7 @@ class MBartDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -595,18 +595,25 @@ MBART_INPUTS_DOCSTRING = r"""
|
||||
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@ -708,11 +715,11 @@ class MBartEncoder(MBartPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
@ -877,7 +884,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@ -914,18 +921,18 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
||||
cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -995,11 +1002,12 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -1033,7 +1041,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@ -1044,7 +1052,9 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -1127,6 +1137,7 @@ class MBartModel(MBartPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1173,7 +1184,7 @@ class MBartModel(MBartPreTrainedModel):
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@ -1254,6 +1265,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1287,6 +1299,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1384,6 +1397,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
@ -1414,6 +1428,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1495,6 +1510,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
@ -1526,6 +1542,7 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1634,7 +1651,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
@ -1667,18 +1684,17 @@ class MBartForCausalLM(MBartPreTrainedModel):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -1735,7 +1751,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
@ -313,7 +313,7 @@ class PegasusEncoderLayer(nn.Module):
|
||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
@ -386,7 +386,7 @@ class PegasusDecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@ -400,9 +400,9 @@ class PegasusDecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@ -438,7 +438,7 @@ class PegasusDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -566,18 +566,25 @@ PEGASUS_INPUTS_DOCSTRING = r"""
|
||||
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@ -679,11 +686,11 @@ class PegasusEncoder(PegasusPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
@ -848,7 +855,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@ -885,18 +892,18 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in decoder to avoid performing
|
||||
cross-attention on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -965,11 +972,12 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -1003,7 +1011,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@ -1014,7 +1022,9 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -1092,6 +1102,7 @@ class PegasusModel(PegasusPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1150,7 +1161,7 @@ class PegasusModel(PegasusPreTrainedModel):
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@ -1232,6 +1243,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1267,6 +1279,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -1390,7 +1403,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
@ -1423,18 +1436,17 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -1491,7 +1503,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
@ -451,7 +451,7 @@ class Speech2TextDecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@ -465,9 +465,9 @@ class Speech2TextDecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
:obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
:obj:`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size :obj:`(config.encoder_attention_heads,)`.
|
||||
:obj:`(encoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@ -503,7 +503,7 @@ class Speech2TextDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -623,19 +623,29 @@ SPEECH_TO_TEXT_INPUTS_DOCSTRING = r"""
|
||||
:obj:`past_key_values`).
|
||||
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
also be used by default. <<<<<<< HEAD
|
||||
|
||||
If you want to change padding behavior, you should read
|
||||
:func:`modeling_speech_to_text._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the
|
||||
paper <https://arxiv.org/abs/1910.13461>`__ for more information on the default strategy.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@ -728,11 +738,11 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@ -884,7 +894,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@ -921,18 +931,18 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -1001,12 +1011,12 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -1039,7 +1049,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@ -1050,7 +1060,9 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -1127,6 +1139,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
decoder_inputs_embeds=None,
|
||||
@ -1166,7 +1179,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@ -1240,6 +1253,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
decoder_inputs_embeds=None,
|
||||
@ -1296,6 +1310,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
@ -607,7 +607,7 @@ class T5Block(nn.Module):
|
||||
encoder_attention_mask=None,
|
||||
encoder_decoder_position_bias=None,
|
||||
layer_head_mask=None,
|
||||
encoder_layer_head_mask=None,
|
||||
cross_attn_layer_head_mask=None,
|
||||
past_key_value=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
@ -661,7 +661,7 @@ class T5Block(nn.Module):
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
position_bias=encoder_decoder_position_bias,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
query_length=query_length,
|
||||
use_cache=use_cache,
|
||||
@ -846,7 +846,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
encoder_attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@ -913,7 +913,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
|
||||
# Prepare head mask if needed
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
||||
encoder_head_mask = self.get_head_mask(encoder_head_mask, self.config.num_layers)
|
||||
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
|
||||
present_key_value_states = () if use_cache else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
@ -925,7 +925,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
|
||||
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
|
||||
layer_head_mask = head_mask[i]
|
||||
encoder_layer_head_mask = encoder_head_mask[i]
|
||||
cross_attn_layer_head_mask = cross_attn_head_mask[i]
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
@ -942,8 +942,8 @@ class T5Stack(T5PreTrainedModel):
|
||||
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
|
||||
if layer_head_mask is not None:
|
||||
layer_head_mask = layer_head_mask.to(hidden_states.device)
|
||||
if encoder_layer_head_mask is not None:
|
||||
encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device)
|
||||
if cross_attn_layer_head_mask is not None:
|
||||
cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
@ -955,7 +955,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
||||
layer_head_mask=layer_head_mask,
|
||||
encoder_layer_head_mask=encoder_layer_head_mask,
|
||||
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
@ -1082,12 +1082,19 @@ T5_INPUTS_DOCSTRING = r"""
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules. in the decoder Mask values selected in ``[0,
|
||||
Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
|
||||
``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
|
||||
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
|
||||
@ -1263,6 +1270,7 @@ class T5Model(T5PreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1338,7 +1346,7 @@ class T5Model(T5PreTrainedModel):
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@ -1451,6 +1459,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -1551,7 +1560,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
|
@ -1041,10 +1041,11 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
@ -1876,7 +1877,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
encoder_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
@ -1890,9 +1891,9 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
`(config.encoder_attention_heads,)`.
|
||||
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
|
||||
size `(config.encoder_attention_heads,)`.
|
||||
`(encoder_attention_heads,)`.
|
||||
cross_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
@ -1927,7 +1928,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=encoder_layer_head_mask,
|
||||
layer_head_mask=cross_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -2070,18 +2071,24 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
||||
If you want to change padding behavior, you should read :func:`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs` and
|
||||
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
|
||||
information on the default strategy.
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
|
||||
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||
@ -2211,10 +2218,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
@ -2377,7 +2385,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
@ -2414,18 +2422,17 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -2493,12 +2500,12 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.size()[0] == (
|
||||
len(self.layers)
|
||||
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
@ -2529,7 +2536,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
encoder_head_mask[idx] if encoder_head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
@ -2540,7 +2547,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
|
||||
cross_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@ -2621,6 +2628,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -2662,7 +2670,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
encoder_head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
@ -2743,6 +2751,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@ -2791,6 +2800,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
@ -3124,7 +3134,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
@ -3157,18 +3167,17 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the heas is **masked**.
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
@ -3225,7 +3234,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_head_mask=encoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
@ -55,6 +55,7 @@ def prepare_bart_inputs_dict(
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
@ -64,6 +65,8 @@ def prepare_bart_inputs_dict(
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
@ -71,6 +74,7 @@ def prepare_bart_inputs_dict(
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
@ -45,6 +45,7 @@ def prepare_blenderbot_inputs_dict(
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
@ -54,6 +55,8 @@ def prepare_blenderbot_inputs_dict(
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
@ -61,6 +64,7 @@ def prepare_blenderbot_inputs_dict(
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
@ -50,6 +50,7 @@ def prepare_blenderbot_small_inputs_dict(
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
@ -59,6 +60,8 @@ def prepare_blenderbot_small_inputs_dict(
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
@ -66,6 +69,7 @@ def prepare_blenderbot_small_inputs_dict(
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
@ -225,8 +225,8 @@ class ModelTesterMixin:
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
expected_arg_names.extend(
|
||||
["head_mask", "decoder_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" in arg_names
|
||||
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
|
||||
else ["encoder_outputs"]
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
@ -492,6 +492,8 @@ class ModelTesterMixin:
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model
|
||||
inputs["decoder_head_mask"] = head_mask
|
||||
if "cross_attn_head_mask" in arg_names:
|
||||
inputs["cross_attn_head_mask"] = head_mask
|
||||
outputs = model(**inputs, return_dict=True)
|
||||
|
||||
# Test that we can get a gradient back for importance score computation
|
||||
@ -523,6 +525,7 @@ class ModelTesterMixin:
|
||||
if model.config.is_encoder_decoder:
|
||||
check_attentions_validity(outputs.encoder_attentions)
|
||||
check_attentions_validity(outputs.decoder_attentions)
|
||||
check_attentions_validity(outputs.cross_attentions)
|
||||
else:
|
||||
check_attentions_validity(outputs.attentions)
|
||||
|
||||
@ -1093,7 +1096,7 @@ class ModelTesterMixin:
|
||||
|
||||
# some params shouldn't be scattered by nn.DataParallel
|
||||
# so just remove them if they are present.
|
||||
blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]
|
||||
blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
|
||||
for k in blacklist_non_batched_params:
|
||||
inputs_dict.pop(k, None)
|
||||
|
||||
|
@ -113,6 +113,7 @@ def prepare_fsmt_inputs_dict(
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
@ -120,6 +121,8 @@ def prepare_fsmt_inputs_dict(
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
|
@ -52,6 +52,7 @@ def prepare_led_inputs_dict(
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
@ -61,6 +62,8 @@ def prepare_led_inputs_dict(
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
@ -68,6 +71,7 @@ def prepare_led_inputs_dict(
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
@ -41,16 +41,28 @@ def prepare_m2m_100_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@ -142,9 +154,10 @@ class M2M100ModelTester:
|
||||
model = M2M100Model(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@ -217,7 +230,6 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
@ -60,6 +60,7 @@ def prepare_marian_inputs_dict(
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
@ -69,6 +70,8 @@ def prepare_marian_inputs_dict(
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
@ -76,6 +79,7 @@ def prepare_marian_inputs_dict(
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
@ -52,6 +52,7 @@ def prepare_mbart_inputs_dict(
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
@ -61,6 +62,8 @@ def prepare_mbart_inputs_dict(
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
@ -68,6 +71,7 @@ def prepare_mbart_inputs_dict(
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
@ -42,6 +42,7 @@ def prepare_pegasus_inputs_dict(
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
@ -51,6 +52,8 @@ def prepare_pegasus_inputs_dict(
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
@ -58,6 +61,7 @@ def prepare_pegasus_inputs_dict(
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
|
@ -55,17 +55,29 @@ def prepare_speech_to_text_inputs_dict(
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_features.ne(0)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||
return {
|
||||
# "input_ids": input_features,
|
||||
"input_features": input_features,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@ -247,7 +259,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
||||
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
test_torchscript = True
|
||||
|
||||
@ -316,8 +327,8 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
expected_arg_names.extend(
|
||||
["head_mask", "decoder_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" in arg_names
|
||||
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
|
||||
else ["encoder_outputs"]
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
Loading…
Reference in New Issue
Block a user