rename BertLayer to BertEncoderLayer

This commit is contained in:
Rémi Louf 2019-10-07 16:31:46 +02:00
parent 0053c0e052
commit dda1adad6d

View File

@ -315,9 +315,9 @@ class BertOutput(nn.Module):
return hidden_states
class BertLayer(nn.Module):
class BertEncoderLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
super(BertEncoderLayer, self).__init__()
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
@ -336,7 +336,7 @@ class BertEncoder(nn.Module):
super(BertEncoder, self).__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None, head_mask=None):
all_hidden_states = ()