mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
merge the two Bert layers classes
This commit is contained in:
parent
edfc8f8225
commit
9ca788b2e8
@ -318,15 +318,26 @@ class BertOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertEncoderLayer(nn.Module):
|
||||
class BertLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertEncoderLayer, self).__init__()
|
||||
super(BertLayer, self).__init__()
|
||||
self.self_attention = BertAttention(config)
|
||||
if config.get('is_decoder', False):
|
||||
self.attention = BertAttention(config)
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
|
||||
self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
|
||||
self_attention_output = self_attention_outputs[0]
|
||||
|
||||
attention_outputs = self_attention_outputs
|
||||
if encoder_hidden_state:
|
||||
try:
|
||||
attention_outputs = self.attention(self_attention_output, attention_mask, head_mask, encoder_hidden_state)
|
||||
except AttributeError as ae:
|
||||
raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
|
||||
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
@ -334,35 +345,12 @@ class BertEncoderLayer(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class BertDecoderLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertDecoderLayer, self).__init__()
|
||||
self.self_attention = BertAttention(config)
|
||||
self.attention = BertDecoderAttention(config)
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
|
||||
def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
|
||||
self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
|
||||
self_attention_output = self_attention_outputs[0]
|
||||
attention_outputs = self.attention(query=self_attention_output,
|
||||
key=encoder_outputs,
|
||||
value=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask)
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
outputs = (layer_output,) + attention_outputs[1:]
|
||||
return outputs
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertEncoder, self).__init__()
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
all_hidden_states = ()
|
||||
@ -392,9 +380,10 @@ class BertEncoder(nn.Module):
|
||||
class BertDecoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertDecoder, self).__init__()
|
||||
config["is_decoder"] = True
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layers = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
|
||||
all_hidden_states = ()
|
||||
@ -403,7 +392,10 @@ class BertDecoder(nn.Module):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
|
||||
layer_outputs = layer_module(hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_state=encoder_outputs)
|
||||
if self.output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user