Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 10:41:07 +06:00

Commit fa218e648a: fix syntax errors
Parent: 3e1cd8241e
@@ -201,7 +201,7 @@ class BertSelfAttention(nn.Module):
     def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
         mixed_key_layer = self.key(hidden_states)
         mixed_value_layer = self.value(hidden_states)
-        if encoder_hidden_states: # if encoder-decoder attention
+        if encoder_hidden_states is not None: # if encoder-decoder attention
             mixed_query_layer = self.query(encoder_hidden_states)
         else:
             mixed_query_layer = self.query(hidden_states)
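Note: this change matters because `encoder_hidden_states` is a torch.Tensor when cross-attention is used, and evaluating a multi-element tensor in a boolean context raises a RuntimeError instead of returning True. A minimal sketch of the difference, using a made-up tensor shape rather than code from this commit:

import torch

encoder_hidden_states = torch.zeros(2, 4, 8)  # hypothetical (batch, seq_len, hidden) activations

try:
    if encoder_hidden_states:  # old check: truthiness of a multi-element tensor
        pass
except RuntimeError as err:
    print("old check fails:", err)  # "Boolean value of Tensor with more than one element is ambiguous"

if encoder_hidden_states is not None:  # fixed check: only tests for a missing argument
    print("cross-attention branch taken")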
@@ -331,11 +331,12 @@ class BertLayer(nn.Module):
         attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
         attention_output = attention_outputs[0]

-        if encoder_hidden_state:
+        if encoder_hidden_state is not None:
             try:
                 attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
             except AttributeError as ae:
-                raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
+                print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae)
+                raise

             attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
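Note: `raise ae(...)` calls the caught AttributeError instance as if it were a class, which itself fails with a TypeError and hides the real problem. The replacement prints a hint and then re-raises with a bare `raise`, so the original traceback is preserved. A self-contained sketch of that pattern, using a hypothetical stand-in class instead of a real BertLayer:

import traceback

class FakeLayer:
    """Hypothetical stand-in for a layer whose config lacked the decoder flag,
    so the `crossattention` submodule was never created."""

def run_cross_attention(layer):
    try:
        return layer.crossattention(None)
    except AttributeError as ae:
        # print a hint, then re-raise the original error with its traceback
        print("You need to set `is_encoder` to True in the configuration:", ae)
        raise

try:
    run_cross_attention(FakeLayer())
except AttributeError:
    traceback.print_exc()  # traceback still points at the failing attribute access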
@@ -382,7 +383,7 @@ class BertDecoder(nn.Module):
         config.is_decoder = True
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
         all_hidden_states = ()
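Note: the rename from `self.layers` to `self.layer` keeps the decoder's attribute name consistent with the code that indexes it elsewhere in this file (e.g. `self.decoder.layer[layer]` in the next hunk) and with the existing encoder. A toy illustration, not taken from the commit:

import torch.nn as nn

class TinyDecoder(nn.Module):
    """Toy module registering its stack under `layer`, as the fix does."""
    def __init__(self, num_layers=2):
        super().__init__()
        self.layer = nn.ModuleList([nn.Linear(4, 4) for _ in range(num_layers)])

decoder = TinyDecoder()
print(decoder.layer[0])            # indexing by `layer` now works
print(hasattr(decoder, "layers"))  # False: code still using the old name would break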
@@ -738,7 +739,7 @@ class BertDecoderModel(BertPreTrainedModel):
             self.decoder.layer[layer].attention.prune_heads(heads)
             self.decoder.layer[layer].crossattention.prune_heads(heads)

-    def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
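Note: the context lines in this hunk also show the default handling the widened signature relies on: when no `attention_mask` is passed, the model attends everywhere by building a mask of ones shaped like `input_ids`. A small sketch with made-up token ids (the new `training` keyword is not illustrated here):

import torch

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # hypothetical token ids

attention_mask = None
if attention_mask is None:
    attention_mask = torch.ones_like(input_ids)  # every position may be attended to

print(attention_mask)  # tensor([[1, 1, 1, 1]])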
@@ -782,7 +783,7 @@ class BertDecoderModel(BertPreTrainedModel):
         sequence_output = decoder_outputs[0]
         pooled_output = self.pooler(sequence_output)

-        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
+        outputs = (sequence_output, pooled_output,) + decoder_outputs[1:]  # add hidden_states and attentions if they are here
         return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

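Note: model outputs in this file follow a tuple convention where element 0 is the main tensor and the tail carries optional extras (hidden states, attentions). The fix appends the decoder's own tail instead of the encoder's. A schematic sketch with string stand-ins for tensors, purely illustrative:

decoder_outputs = ("decoder_sequence", "decoder_hidden_states", "decoder_attentions")
encoder_outputs = ("encoder_sequence", "encoder_hidden_states", "encoder_attentions")

sequence_output = decoder_outputs[0]
pooled_output = "decoder_pooled"

outputs = (sequence_output, pooled_output,) + decoder_outputs[1:]  # fixed slice
print(outputs)  # the extras now come from the decoder, not the encoder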
@@ -1387,8 +1388,7 @@ class Bert2Rnd(BertPreTrainedModel):
                                        head_mask=head_mask)
         encoder_output = encoder_outputs[0]

-        decoder_input = torch.empty_like(input_ids).normal_(mean=0.0, std=self.config.initializer_range)
-        decoder_outputs = self.decoder(decoder_input,
+        decoder_outputs = self.decoder(input_ids,
                                        encoder_output,
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
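Note: the removed line tried to fill a tensor shaped like `input_ids` with normal samples. Since `input_ids` is an integer tensor, `normal_` fails at runtime, and even a float sample could not serve as embedding indices, so the fix feeds `input_ids` to the decoder directly. A sketch of the failure with made-up ids (0.02 stands in for `initializer_range`):

import torch

input_ids = torch.tensor([[101, 2023, 102]])  # hypothetical token ids, dtype long

try:
    decoder_input = torch.empty_like(input_ids).normal_(mean=0.0, std=0.02)  # removed line
except RuntimeError as err:
    print("cannot draw normal samples into an integer tensor:", err)

decoder_input = input_ids  # fixed: pass the token ids straight through
print(decoder_input.dtype)  # torch.int64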