correct composition of padding and causal masks

Rémi Louf 2019-10-17 10:13:07 +02:00
parent 4e0f24348f
commit 638fe7f5a4


@@ -288,8 +288,8 @@ class BertAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
@@ -350,7 +350,6 @@ class BertLayer(nn.Module):
return outputs
# NOTE I think we may need to call encoder_hidden_states[i] for each layer
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
@@ -365,7 +364,8 @@ class BertEncoder(nn.Module):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
encoder_hidden_state = encoder_hidden_states[i]
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_state, encoder_attention_mask)
hidden_states = layer_outputs[0]
if self.output_attentions:
@@ -607,22 +607,26 @@ class BertModel(BertPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
""" Forward pass on the Model.
The values of the attention mask (shape [batch_size, seq_length])
should be 1.0 for the positions we want to attend to and 0.0 for the
ones we do not want to attend to.
The model can behave as an encoder (with only self-attention) as well
as a decoder, in which case a layer of cross-attention is added between
every self-attention layer, following the architecture described in [1].
To behave as a decoder, the model needs to be initialized with the
`is_decoder` argument of the config set to `True`. An
`encoder_hidden_state` is expected as an input to the forward pass.
`encoder_hidden_states` is expected as an input to the forward pass.
When used as a decoder, there are two kinds of attention masks to specify:
(1) A self-attention mask that needs to be causal (each token only
attends to previous tokens);
(2) A cross-attention mask that prevents the module
from attending to the encoder' padding tokens.
from attending to the encoder's padding tokens.
[1] Vaswani, Ashish, et al. "Attention is all you need." Advances in
neural information processing systems. 2017.
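To make the two mask types concrete, here is a minimal sketch with toy values (the tensor names are made up for the example and are not part of the commit):

import torch

# Padding mask: 1.0 for real tokens, 0.0 for padding, shape [batch_size, seq_length].
padding_mask = torch.tensor([[1., 1., 1., 0.]])

# Causal self-attention mask for a decoder: token i may only attend to tokens j <= i.
seq_length = padding_mask.size(1)
causal_mask = torch.tril(torch.ones(seq_length, seq_length))  # lower-triangular [seq_length, seq_length]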
@@ -632,20 +636,20 @@ class BertModel(BertPreTrainedModel):
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just make it broadcastable to all heads.
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
# provided a padding mask of dimensions [batch_size, seq_length]
# - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
# - if decoder, make it causal
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if attention_mask.dim() == 2:
if self.config.is_decoder:
batch_size, seq_length = input_ids.size()
seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[None, None, :, :]
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
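A minimal sketch of the corrected decoder branch above, with made-up toy sizes (not part of the commit), to show how the padding and causal masks compose:

import torch

batch_size, seq_length = 2, 4
# Padding mask: the second sequence ends with two padding tokens.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

seq_ids = torch.arange(seq_length)
# One lower-triangular causal mask per batch element: [batch_size, seq_length, seq_length].
causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]).long()

# Corrected broadcasting: attention_mask[:, None, None, :] keeps each batch row's own padding pattern.
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
print(extended_attention_mask.shape)  # torch.Size([2, 1, 4, 4])

# The previous indexing, attention_mask[None, None, :, :], broadcast the padding mask over the
# wrong axes and raises a shape error unless batch_size happens to equal seq_length.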
@@ -676,7 +680,7 @@ class BertModel(BertPreTrainedModel):
encoder_outputs = self.encoder(embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_state=encoder_hidden_state,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)