fix syntax errors

This commit is contained in:
Rémi Louf 2019-10-10 15:16:07 +02:00
parent 3e1cd8241e
commit fa218e648a

View File

@ -201,7 +201,7 @@ class BertSelfAttention(nn.Module):
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
if encoder_hidden_states: # if encoder-decoder attention
if encoder_hidden_states is not None: # if encoder-decoder attention
mixed_query_layer = self.query(encoder_hidden_states)
else:
mixed_query_layer = self.query(hidden_states)
@ -331,11 +331,12 @@ class BertLayer(nn.Module):
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
attention_output = attention_outputs[0]
if encoder_hidden_state:
if encoder_hidden_state is not None:
try:
attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
except AttributeError as ae:
raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae)
raise
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
@ -382,7 +383,7 @@ class BertDecoder(nn.Module):
config.is_decoder = True
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
all_hidden_states = ()
@ -738,7 +739,7 @@ class BertDecoderModel(BertPreTrainedModel):
self.decoder.layer[layer].attention.prune_heads(heads)
self.decoder.layer[layer].crossattention.prune_heads(heads)
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
@ -782,7 +783,7 @@ class BertDecoderModel(BertPreTrainedModel):
sequence_output = decoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
outputs = (sequence_output, pooled_output,) + decoder_outputs[1:] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
@ -1387,8 +1388,7 @@ class Bert2Rnd(BertPreTrainedModel):
head_mask=head_mask)
encoder_output = encoder_outputs[0]
decoder_input = torch.empty_like(input_ids).normal_(mean=0.0, std=self.config.initializer_range)
decoder_outputs = self.decoder(decoder_input,
decoder_outputs = self.decoder(input_ids,
encoder_output,
token_type_ids=token_type_ids,
position_ids=position_ids,