Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
Revert "add output_attentions for BertModel"
This reverts commit de5e5682a1
.
This commit is contained in:
parent
de5e5682a1
commit
826496580b
@@ -275,7 +275,7 @@ class BertEmbeddings(nn.Module):
 
 
 class BertSelfAttention(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
@@ -291,8 +291,6 @@ class BertSelfAttention(nn.Module):
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
-        self.output_attentions = output_attentions
-
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
@@ -324,10 +322,7 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        if self.output_attentions:
-            return attention_probs, context_layer
-        else:
-            return context_layer
+        return context_layer
 
 
 class BertSelfOutput(nn.Module):
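Note on the hunk above: `attention_probs` are the softmax-normalized attention scores that `BertSelfAttention.forward` already computes internally; the reverted flag only controlled whether they were returned alongside `context_layer`. Below is a minimal, self-contained sketch of that pattern (single head, no masking or dropout) for illustration only; it is not the library's code, and every name other than `output_attentions`, `attention_probs`, and `context_layer` is invented for the example.

```python
import math
import torch
import torch.nn as nn

class TinySelfAttention(nn.Module):
    """Illustrative single-head self-attention mirroring the reverted pattern:
    optionally return the softmaxed attention probabilities with the output."""

    def __init__(self, hidden_size, output_attentions=False):
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.output_attentions = output_attentions

    def forward(self, hidden_states):
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
        # Scaled dot-product scores, then softmax -> these are the "attention_probs".
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
        attention_probs = nn.functional.softmax(scores, dim=-1)
        context_layer = torch.matmul(attention_probs, v)
        if self.output_attentions:
            return attention_probs, context_layer  # the branch the revert removes
        return context_layer

attn = TinySelfAttention(hidden_size=8, output_attentions=True)
probs, ctx = attn(torch.randn(2, 5, 8))  # batch=2, seq_len=5, hidden=8
print(probs.shape, ctx.shape)            # (2, 5, 5) and (2, 5, 8)
```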
@@ -386,43 +381,33 @@ class BertOutput(nn.Module):
 
 
 class BertLayer(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertLayer, self).__init__()
         self.attention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
-        self.output_attentions = output_attentions
 
     def forward(self, hidden_states, attention_mask):
         attention_output = self.attention(hidden_states, attention_mask)
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        if self.output_attentions:
-            return attention_output, layer_output
         return layer_output
 
 
 class BertEncoder(nn.Module):
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertEncoder, self).__init__()
-        layer = BertLayer(config, output_attentions=output_attentions)
+        layer = BertLayer(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
-        self.output_attentions = output_attentions
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
-        all_attentions = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
-            if self.output_attentions:
-                attentions, hidden_states = hidden_states
-                all_attentions.append(attentions)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-        if self.output_attentions:
-            return all_attentions, all_encoder_layers
         return all_encoder_layers
 
 
@@ -714,13 +699,12 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config):
         super(BertModel, self).__init__(config)
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config, output_attentions=output_attentions)
+        self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
-        self.output_attentions = output_attentions
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
         if attention_mask is None:
@@ -747,14 +731,10 @@ class BertModel(BertPreTrainedModel):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
-        if self.output_attentions:
-            all_attentions, encoded_layers = encoded_layers
         sequence_output = encoded_layers[-1]
         pooled_output = self.pooler(sequence_output)
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
-        if self.output_attentions:
-            return all_attentions, encoded_layers, pooled_output
         return encoded_layers, pooled_output
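Taken together, the reverted code threaded one convention through the stack: with `output_attentions` enabled, each component returned its attention tensors prepended to its usual output, and the caller one level up unpacked the pair in place (`attentions, hidden_states = hidden_states` in `BertEncoder`, `all_attentions, encoded_layers = encoded_layers` in `BertModel`), so `BertModel.forward` ended up returning `(all_attentions, encoded_layers, pooled_output)` instead of `(encoded_layers, pooled_output)`. In the real model, each element of `all_attentions` is one attention-probability tensor per layer, of shape `(batch_size, num_attention_heads, seq_len, seq_len)`. The sketch below isolates that calling convention with plain stand-in functions; it is an illustration of the reverted pattern under invented names, not library code.

```python
# Minimal sketch of the reverted calling convention, with strings standing in for
# tensors. With the flag on, each call returns (attentions, usual_output) and the
# caller unpacks it in place, as the reverted BertEncoder/BertModel did.

def layer(hidden, output_attentions=False):
    attentions = f"attn({hidden})"    # stand-in for a layer's attention maps
    new_hidden = f"layer({hidden})"   # stand-in for the transformed hidden states
    if output_attentions:
        return attentions, new_hidden
    return new_hidden

def encoder(hidden, num_layers=3, output_attentions=False):
    all_attentions, all_encoder_layers = [], []
    for _ in range(num_layers):
        hidden = layer(hidden, output_attentions)
        if output_attentions:
            attentions, hidden = hidden        # the in-place unpacking seen in the diff
            all_attentions.append(attentions)
        all_encoder_layers.append(hidden)
    if output_attentions:
        return all_attentions, all_encoder_layers
    return all_encoder_layers

# Pre-revert shape of the contract: attentions ride along with the usual output.
all_attentions, encoded_layers = encoder("x", output_attentions=True)
# Post-revert (and pre-PR) shape: only the usual output comes back.
encoded_layers = encoder("x")
print(len(all_attentions), len(encoded_layers))   # 3 3
```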