Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)
fix syntax errors

parent 3e1cd8241e
commit fa218e648a
@@ -201,7 +201,7 @@ class BertSelfAttention(nn.Module):
     def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
         mixed_key_layer = self.key(hidden_states)
         mixed_value_layer = self.value(hidden_states)
-        if encoder_hidden_states:  # if encoder-decoder attention
+        if encoder_hidden_states is not None:  # if encoder-decoder attention
             mixed_query_layer = self.query(encoder_hidden_states)
         else:
             mixed_query_layer = self.query(hidden_states)
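A minimal, self-contained sketch (standalone, not part of the patch) of why the check changes from truthiness to `is not None`: evaluating a multi-element tensor in a boolean context raises a RuntimeError, so `if encoder_hidden_states:` cannot work once a real hidden-state tensor is passed.

import torch

encoder_hidden_states = torch.zeros(1, 4, 8)   # any multi-element tensor
try:
    if encoder_hidden_states:                  # ambiguous: bool() of a multi-element tensor
        pass
except RuntimeError as err:
    print(err)                                 # "Boolean value of Tensor with more than one element is ambiguous"

if encoder_hidden_states is not None:          # the unambiguous check used after the fix
    print("cross-attention branch taken")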
@@ -331,11 +331,12 @@ class BertLayer(nn.Module):
         attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
         attention_output = attention_outputs[0]

-        if encoder_hidden_state:
+        if encoder_hidden_state is not None:
             try:
                 attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
             except AttributeError as ae:
-                raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
+                print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae)
+                raise

         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
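A short illustrative sketch (hypothetical `Layer` class, not code from the repository) of the error-handling pattern this hunk moves to: the old `raise ae(...)` tries to call the caught AttributeError instance, which fails on its own, while printing the hint and re-raising with a bare `raise` preserves the original exception and traceback.

class Layer:
    pass  # no `crossattention` attribute, as when the decoder/encoder flag is unset

def run(layer):
    try:
        return layer.crossattention  # AttributeError if the sub-module is missing
    except AttributeError as ae:
        # `raise ae("msg")` would fail on its own: an exception *instance* is not callable.
        print("You need to set `is_encoder` to True in the configuration:", ae)
        raise  # re-raise the original AttributeError with its traceback intact

try:
    run(Layer())
except AttributeError:
    print("original exception propagated unchanged")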
@@ -382,7 +383,7 @@ class BertDecoder(nn.Module):
         config.is_decoder = True
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
         all_hidden_states = ()
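A hedged sketch (toy `TinyDecoder`, not the real class) of why the attribute is renamed from `layers` to `layer`: the pruning code shown in a later hunk indexes `self.decoder.layer[layer]`, and BertEncoder uses the singular `layer` name, so both the indexing and BERT-style parameter names assume that spelling.

import torch.nn as nn

# Toy stand-in for BertDecoder: the ModuleList must be called `layer` so that
# accesses such as `self.decoder.layer[layer_idx]` (see the pruning hunk below)
# and parameter names like "layer.0.weight" line up with the BERT convention.
class TinyDecoder(nn.Module):
    def __init__(self, num_hidden_layers=2):
        super().__init__()
        self.layer = nn.ModuleList([nn.Linear(4, 4) for _ in range(num_hidden_layers)])

decoder = TinyDecoder()
print(decoder.layer[0])                                   # indexing used by head pruning
print(sorted(n for n, _ in decoder.named_parameters()))   # ['layer.0.bias', 'layer.0.weight', ...]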
@@ -738,7 +739,7 @@ class BertDecoderModel(BertPreTrainedModel):
             self.decoder.layer[layer].attention.prune_heads(heads)
             self.decoder.layer[layer].crossattention.prune_heads(heads)

-    def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
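A small sketch (toy tensors only) of the default-mask logic visible as context in this hunk: when no attention_mask is passed, `torch.ones_like(input_ids)` marks every token as attendable, matching the input's shape, dtype, and device.

import torch

input_ids = torch.tensor([[101, 7592, 2088, 102]])    # placeholder token ids
attention_mask = torch.ones_like(input_ids)           # default: attend to every position
print(attention_mask)                                 # tensor([[1, 1, 1, 1]])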
@@ -782,7 +783,7 @@ class BertDecoderModel(BertPreTrainedModel):
         sequence_output = decoder_outputs[0]
         pooled_output = self.pooler(sequence_output)

-        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
+        outputs = (sequence_output, pooled_output,) + decoder_outputs[1:]  # add hidden_states and attentions if they are here
         return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

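An illustrative sketch with placeholder values (not real model outputs) of what the `encoder_outputs[1:]` to `decoder_outputs[1:]` change means: the optional hidden_states/attentions appended to the returned tuple now come from the decoder rather than from the encoder.

# Placeholder tuples standing in for the real model outputs.
decoder_outputs = ("decoder_sequence_output", "decoder_hidden_states", "decoder_attentions")
encoder_outputs = ("encoder_sequence_output", "encoder_hidden_states")

sequence_output = decoder_outputs[0]
pooled_output = "pooled_output"
outputs = (sequence_output, pooled_output,) + decoder_outputs[1:]
print(outputs)
# ('decoder_sequence_output', 'pooled_output', 'decoder_hidden_states', 'decoder_attentions')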
@@ -1387,8 +1388,7 @@ class Bert2Rnd(BertPreTrainedModel):
                                        head_mask=head_mask)
         encoder_output = encoder_outputs[0]

-        decoder_input = torch.empty_like(input_ids).normal_(mean=0.0, std=self.config.initializer_range)
-        decoder_outputs = self.decoder(decoder_input,
+        decoder_outputs = self.decoder(input_ids,
                                        encoder_output,
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
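A minimal sketch of one problem with the removed `decoder_input` line: `torch.empty_like(input_ids)` inherits the integer dtype of the token ids, and `normal_()` is only defined for floating-point tensors, so sampling a random decoder input this way fails at runtime; after the change the decoder is simply fed `input_ids`.

import torch

input_ids = torch.tensor([[101, 7592, 102]])    # integer (long) token ids
try:
    decoder_input = torch.empty_like(input_ids).normal_(mean=0.0, std=0.02)
except RuntimeError as err:
    print(err)   # normal_() only supports floating-point tensors, not the Long dtype of input_ids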