Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
correct composition of padding and causal masks
commit 638fe7f5a4
parent 4e0f24348f
@@ -288,8 +288,8 @@ class BertAttention(nn.Module):
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
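For context, the singular encoder_hidden_state passed down here is the tensor a cross-attention module would use as its source of keys and values, while queries come from the decoder's own hidden states. The internals of self.self are not part of this diff, so the following is only an illustrative sketch with a hypothetical toy_cross_attention helper (single head, no dropout or output projection):

import torch

def toy_cross_attention(hidden_states, encoder_hidden_state):
    # hypothetical helper, not part of the diff: queries come from the decoder side,
    # keys and values from the (per-layer) encoder hidden state
    d = hidden_states.size(-1)
    q = hidden_states                            # [batch, tgt_len, d]
    k = v = encoder_hidden_state                 # [batch, src_len, d]
    scores = q @ k.transpose(-1, -2) / d ** 0.5  # [batch, tgt_len, src_len]
    return scores.softmax(dim=-1) @ v            # [batch, tgt_len, d]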
@@ -350,7 +350,6 @@ class BertLayer(nn.Module):
        return outputs


# NOTE I think we may need to call encoder_hidden_states[i] for each layer
class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
@@ -365,7 +364,8 @@ class BertEncoder(nn.Module):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
            encoder_hidden_state = encoder_hidden_states[i]
            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_state, encoder_attention_mask)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
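The new lines select one encoder tensor per layer before calling the layer module. A minimal, self-contained sketch of that per-layer selection (dummy tensors and sizes; the real layer call takes more arguments and returns a tuple):

import torch

num_layers, batch, seq, dim = 2, 1, 4, 8
hidden_states = torch.zeros(batch, seq, dim)
# one encoder tensor per layer, which is what the indexing above assumes
encoder_hidden_states = [torch.zeros(batch, seq, dim) for _ in range(num_layers)]

for i in range(num_layers):
    encoder_hidden_state = encoder_hidden_states[i]  # per-layer cross-attention input
    # the real loop would then call:
    # layer_module(hidden_states, attention_mask, head_mask[i],
    #              encoder_hidden_state, encoder_attention_mask)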
@@ -607,22 +607,26 @@ class BertModel(BertPreTrainedModel):
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
                head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
                head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
        """ Forward pass on the Model.

        The values of the attention matrix (shape [batch_size, seq_length])
        should be 1.0 for the position we want to attend to and 0. for the ones
        we do not want to attend to.

        The model can behave as an encoder (with only self-attention) as well
        as a decoder, in which case a layer of cross-attention is added between
        every self-attention layer, following the architecture described in [1].

        To behave as a decoder the model needs to be initialized with the
        `is_decoder` argument of the config set to `True`. An
        `encoder_hidden_state` is expected as an input to the forward pass.
        `encoder_hidden_states` is expected as an input to the forward pass.
        When a decoder, there are two kinds of attention masks to specify:

        (1) Self-attention masks that need to be causal (only attends to
            previous tokens);
        (2) A cross-attention mask that prevents the module
            from attending to the encoder' padding tokens.
            from attending to the encoder's padding tokens.

        [1] Vaswani, Ashish, et al. "Attention is all you need." Advances in
            neural information processing systems. 2017.
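A small illustration of the two masks the docstring describes, for a batch of one sequence of length 4 whose last token is padding (values follow the 1.0 = attend / 0.0 = ignore convention stated above; sizes are made up for the example):

import torch

seq_length = 4
padding_mask = torch.tensor([[1., 1., 1., 0.]])           # [batch, seq]: last token is padding
seq_ids = torch.arange(seq_length)
causal = (seq_ids[None, :] <= seq_ids[:, None]).float()   # [seq, seq] lower-triangular causal mask
# position i may attend to position j only if j <= i and j is not a padding token
combined = causal[None, :, :] * padding_mask[:, None, :]  # [batch, seq, seq]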
@@ -632,20 +636,20 @@ class BertModel(BertPreTrainedModel):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just make it broadcastable to all heads.
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]

        # provided a padding mask of dimensions [batch_size, seq_length]
        # - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        # - if decoder, make it causal
        # Provided a padding mask of dimensions [batch_size, seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if attention_mask.dim() == 2:
            if self.config.is_decoder:
                batch_size, seq_length = input_ids.size()
                seq_ids = torch.arange(seq_length)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[None, None, :, :]
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
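This is the headline fix. The shapes explain why: causal_mask[:, None, :, :] is [batch_size, 1, seq_length, seq_length], and the corrected attention_mask[:, None, None, :] is [batch_size, 1, 1, seq_length], so the product pairs each example with its own padding mask. The old attention_mask[None, None, :, :] had shape [1, 1, batch_size, seq_length], which mixes the batch and sequence dimensions and only broadcasts at all when batch_size happens to equal seq_length (or 1). A standalone sketch with dummy sizes that checks the corrected composition:

import torch

batch_size, seq_length = 2, 5
attention_mask = torch.ones(batch_size, seq_length)
attention_mask[0, -2:] = 0  # pretend the first example ends with two padding tokens

seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
causal_mask = causal_mask.to(attention_mask.dtype)

# corrected composition: [batch, 1, seq, seq] * [batch, 1, 1, seq] -> [batch, 1, seq, seq]
extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
assert extended.shape == (batch_size, 1, seq_length, seq_length)

# the old indexing, attention_mask[None, None, :, :], has shape [1, 1, batch, seq]
# and would mis-broadcast against the batch of causal masks (or fail when batch != seq)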
@@ -676,7 +680,7 @@ class BertModel(BertPreTrainedModel):
        encoder_outputs = self.encoder(embedding_output,
                                       attention_mask=extended_attention_mask,
                                       head_mask=head_mask,
                                       encoder_hidden_state=encoder_hidden_state,
                                       encoder_hidden_states=encoder_hidden_states,
                                       encoder_attention_mask=encoder_attention_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)
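Putting it together, a hedged usage sketch of the decoder-side call based only on the signature shown above. Model construction, weights, and the exact expected structure of encoder_hidden_states are assumptions; given the per-layer indexing earlier in this diff, it would presumably be one encoder tensor per decoder layer:

import torch

num_layers, hidden_size = 12, 768
input_ids = torch.randint(0, 100, (1, 6))    # decoder input tokens
attention_mask = torch.ones(1, 6)            # decoder padding mask; a causal mask is composed with it internally
encoder_hidden_states = [torch.randn(1, 10, hidden_size) for _ in range(num_layers)]
encoder_attention_mask = torch.ones(1, 10)   # masks the encoder's padding tokens

# assuming `model` is a BertModel whose config has is_decoder=True:
# outputs = model(input_ids,
#                 attention_mask=attention_mask,
#                 encoder_hidden_states=encoder_hidden_states,
#                 encoder_attention_mask=encoder_attention_mask)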