Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)
merge the two Bert layers classes
commit 9ca788b2e8
parent edfc8f8225
@@ -318,15 +318,26 @@ class BertOutput(nn.Module):
         return hidden_states
 
 
-class BertEncoderLayer(nn.Module):
+class BertLayer(nn.Module):
     def __init__(self, config):
-        super(BertEncoderLayer, self).__init__()
-        self.attention = BertAttention(config)
+        super(BertLayer, self).__init__()
+        self.self_attention = BertAttention(config)
+        if config.get('is_decoder', False):
+            self.attention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
+        self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
+        self_attention_output = self_attention_outputs[0]
+
+        attention_outputs = self_attention_outputs
+        if encoder_hidden_state:
+            try:
+                attention_outputs = self.attention(self_attention_output, attention_mask, head_mask, encoder_hidden_state)
+            except AttributeError as ae:
+                raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
+
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
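For orientation, here is a minimal self-contained sketch of the pattern the merged BertLayer follows: a single class that only builds a cross-attention sub-module when the configuration marks it as a decoder. The nn.Linear stand-ins, names, and shapes below are illustrative assumptions, not the repository's BertAttention / BertIntermediate / BertOutput API.

# Toy sketch of the merged encoder/decoder layer pattern (illustrative only).
import torch
from torch import nn

class ToyMergedLayer(nn.Module):
    def __init__(self, hidden_size, is_decoder=False):
        super().__init__()
        self.is_decoder = is_decoder
        self.self_attention = nn.Linear(hidden_size, hidden_size)       # stand-in for self-attention
        if is_decoder:
            self.cross_attention = nn.Linear(hidden_size, hidden_size)  # built only for decoder layers
        self.feed_forward = nn.Linear(hidden_size, hidden_size)         # stand-in for intermediate + output

    def forward(self, hidden_states, encoder_hidden_state=None):
        out = self.self_attention(hidden_states)
        if encoder_hidden_state is not None:   # explicit None check; truth-testing a multi-element tensor is ambiguous
            if not self.is_decoder:
                raise AttributeError("set is_decoder=True to pass encoder hidden states to this layer")
            # a real layer would attend over encoder_hidden_state; a sum keeps the sketch tiny
            out = self.cross_attention(out + encoder_hidden_state)
        return self.feed_forward(out)

encoder_layer = ToyMergedLayer(hidden_size=8)                   # encoder flavour: no cross-attention
decoder_layer = ToyMergedLayer(hidden_size=8, is_decoder=True)  # decoder flavour: cross-attention added
x = torch.randn(2, 4, 8)
memory = torch.randn(2, 4, 8)
print(encoder_layer(x).shape)          # torch.Size([2, 4, 8])
print(decoder_layer(x, memory).shape)  # torch.Size([2, 4, 8])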
@@ -334,35 +345,12 @@ class BertEncoderLayer(nn.Module):
         return outputs
 
 
-class BertDecoderLayer(nn.Module):
-    def __init__(self, config):
-        super(BertDecoderLayer, self).__init__()
-        self.self_attention = BertAttention(config)
-        self.attention = BertDecoderAttention(config)
-        self.intermediate = BertIntermediate(config)
-        self.output = BertOutput(config)
-
-    def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
-        self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
-        self_attention_output = self_attention_outputs[0]
-        attention_outputs = self.attention(query=self_attention_output,
-                                           key=encoder_outputs,
-                                           value=encoder_outputs,
-                                           attention_mask=attention_mask,
-                                           head_mask=head_mask)
-        attention_output = attention_outputs[0]
-        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        outputs = (layer_output,) + attention_outputs[1:]
-        return outputs
-
-
 class BertEncoder(nn.Module):
     def __init__(self, config):
         super(BertEncoder, self).__init__()
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layer = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, attention_mask=None, head_mask=None):
         all_hidden_states = ()
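The removed BertDecoderLayer above fed the decoder's self-attention output in as queries and the encoder outputs as keys and values. Below is a minimal plain-PyTorch sketch of that cross-attention pattern, single head and no masking, offered as an illustration of the idea rather than the repository's BertDecoderAttention.

# Single-head cross-attention: queries from decoder states, keys/values from encoder outputs.
import math
import torch

def cross_attention(query, key, value):
    # query: (batch, tgt_len, dim); key, value: (batch, src_len, dim)
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(query.size(-1))
    weights = torch.softmax(scores, dim=-1)   # (batch, tgt_len, src_len)
    return torch.matmul(weights, value)       # (batch, tgt_len, dim)

decoder_states = torch.randn(2, 5, 8)    # e.g. the layer's self_attention_output
encoder_outputs = torch.randn(2, 7, 8)   # e.g. the encoder's hidden states
context = cross_attention(decoder_states, encoder_outputs, encoder_outputs)
print(context.shape)                     # torch.Size([2, 5, 8])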
@@ -392,9 +380,10 @@ class BertEncoder(nn.Module):
 class BertDecoder(nn.Module):
     def __init__(self, config):
         super(BertDecoder, self).__init__()
+        config["is_decoder"] = True
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layers = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
         all_hidden_states = ()
@@ -403,7 +392,10 @@ class BertDecoder(nn.Module):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
+            layer_outputs = layer_module(hidden_states,
+                                         attention_mask=attention_mask,
+                                         head_mask=head_mask[i],
+                                         encoder_hidden_state=encoder_outputs)
             if self.output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
 