diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py
index 3187d1ca506..20b49c592f3 100644
--- a/transformers/modeling_bert.py
+++ b/transformers/modeling_bert.py
@@ -315,9 +315,9 @@ class BertOutput(nn.Module):
         return hidden_states
 
 
-class BertLayer(nn.Module):
+class BertEncoderLayer(nn.Module):
     def __init__(self, config):
-        super(BertLayer, self).__init__()
+        super(BertEncoderLayer, self).__init__()
         self.attention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
@@ -336,7 +336,7 @@ class BertEncoder(nn.Module):
         super(BertEncoder, self).__init__()
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, attention_mask=None, head_mask=None):
         all_hidden_states = ()