merge the two Bert layer classes

Rémi Louf 2019-10-10 11:33:28 +02:00
parent edfc8f8225
commit 9ca788b2e8


@@ -318,15 +318,26 @@ class BertOutput(nn.Module):
        return hidden_states

class BertEncoderLayer(nn.Module):
class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertEncoderLayer, self).__init__()
        super(BertLayer, self).__init__()
        self.self_attention = BertAttention(config)
        if config.get('is_decoder', False):
            self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
        self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
        self_attention_output = self_attention_outputs[0]
        attention_outputs = self_attention_outputs

        if encoder_hidden_state is not None:
            try:
                attention_outputs = self.attention(self_attention_output, attention_mask, head_mask, encoder_hidden_state)
            except AttributeError as ae:
                raise AttributeError("you need to set `is_decoder` to True in the configuration to instantiate a decoder layer") from ae

        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
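For orientation, here is a minimal, self-contained sketch of the pattern this hunk introduces: a single layer class that only builds a cross-attention sub-module when its configuration marks it as a decoder, and only runs it when an encoder state is passed to forward. All names below (TinyLayer, hidden_size, num_heads) are illustrative stand-ins, not this repository's classes, and nn.MultiheadAttention stands in for BertAttention.

import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    def __init__(self, hidden_size=32, num_heads=4, is_decoder=False):
        super(TinyLayer, self).__init__()
        self.self_attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        if is_decoder:
            # the cross-attention sub-module only exists on decoder layers
            self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, encoder_hidden_state=None):
        out, _ = self.self_attention(hidden_states, hidden_states, hidden_states)
        if encoder_hidden_state is not None:
            if not hasattr(self, "attention"):
                raise AttributeError("set is_decoder=True to build a decoder layer")
            # cross-attention: queries from the decoder, keys/values from the encoder
            out, _ = self.attention(out, encoder_hidden_state, encoder_hidden_state)
        return self.output(out)

# encoder-style call: no encoder state, self-attention only
layer = TinyLayer()
x = torch.randn(2, 7, 32)
print(layer(x).shape)                                      # torch.Size([2, 7, 32])

# decoder-style call: cross-attend over the encoder's output
decoder_layer = TinyLayer(is_decoder=True)
enc = torch.randn(2, 11, 32)
print(decoder_layer(x, encoder_hidden_state=enc).shape)    # torch.Size([2, 7, 32])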
@@ -334,35 +345,12 @@ class BertEncoderLayer(nn.Module):
        return outputs

class BertDecoderLayer(nn.Module):
    def __init__(self, config):
        super(BertDecoderLayer, self).__init__()
        self.self_attention = BertAttention(config)
        self.attention = BertDecoderAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
        self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
        self_attention_output = self_attention_outputs[0]
        attention_outputs = self.attention(query=self_attention_output,
                                           key=encoder_outputs,
                                           value=encoder_outputs,
                                           attention_mask=attention_mask,
                                           head_mask=head_mask)
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + attention_outputs[1:]
        return outputs
class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        all_hidden_states = ()
@@ -392,9 +380,10 @@ class BertEncoder(nn.Module):
class BertDecoder(nn.Module):
    def __init__(self, config):
        super(BertDecoder, self).__init__()
        config["is_decoder"] = True
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layers = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
        all_hidden_states = ()
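The decoder stack above forces the flag before building its layers. As a tiny sketch of that flag selection, with a plain dict standing in for the configuration (an assumption: the real BertConfig is an object, while this diff mixes config["is_decoder"] / config.get(...) with attribute access):

# Illustrative only: a dict stands in for the configuration object.
base_config = {"is_decoder": False, "num_hidden_layers": 2, "hidden_size": 32}

encoder_config = dict(base_config)                   # encoder keeps the default
decoder_config = dict(base_config, is_decoder=True)  # decoder forces the flag

print(encoder_config.get("is_decoder", False))       # False -> no cross-attention module
print(decoder_config.get("is_decoder", False))       # True  -> cross-attention module is built

Copying before flipping, rather than mutating a shared config in place, would keep an encoder built later from the same object from silently becoming a decoder; whether that matters here depends on how the config is shared between the two stacks.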
@@ -403,7 +392,10 @@ class BertDecoder(nn.Module):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
            layer_outputs = layer_module(hidden_states,
                                         attention_mask=attention_mask,
                                         head_mask=head_mask[i],
                                         encoder_hidden_state=encoder_outputs)

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
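As a rough, self-contained sketch of the call pattern in this loop, where every decoder layer receives the same encoder output as encoder_hidden_state, with a dummy callable standing in for BertLayer (names and shapes are illustrative, not the repository's API):

import torch

def dummy_layer(hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
    # stand-in for BertLayer.forward: returns (layer_output, attention_probs)
    seq_len = hidden_states.size(1)
    attention_probs = torch.zeros(hidden_states.size(0), seq_len, seq_len)
    return (hidden_states, attention_probs)

layers = [dummy_layer, dummy_layer]
hidden_states = torch.randn(2, 5, 16)      # decoder input
encoder_outputs = torch.randn(2, 9, 16)    # output of the encoder stack
head_mask = [None] * len(layers)

all_attentions = ()
for i, layer_module in enumerate(layers):
    layer_outputs = layer_module(hidden_states,
                                 attention_mask=None,
                                 head_mask=head_mask[i],
                                 encoder_hidden_state=encoder_outputs)
    hidden_states = layer_outputs[0]
    all_attentions = all_attentions + (layer_outputs[1],)

print(hidden_states.shape, len(all_attentions))   # torch.Size([2, 5, 16]) 2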